diff --git a/sonnet/src/metrics.py b/sonnet/src/metrics.py index fb871a0..abada45 100644 --- a/sonnet/src/metrics.py +++ b/sonnet/src/metrics.py @@ -62,6 +62,7 @@ def initialize(self, value: tf.Tensor): def update(self, value: tf.Tensor): """See base class.""" self.initialize(value) + assert self.sum is not None self.sum.assign_add(value) @property @@ -71,6 +72,8 @@ def value(self) -> tf.Tensor: def reset(self): """See base class.""" + if self.sum is None: + raise ValueError("Metric is not initialized. Call `initialize` first.") self.sum.assign(tf.zeros_like(self.sum)) @@ -90,15 +93,24 @@ def initialize(self, value: tf.Tensor): def update(self, value: tf.Tensor): """See base class.""" self.initialize(value) + assert self.sum is not None self.sum.assign_add(value) self.count.assign_add(1) + @property + def _checked_sum(self) -> tf.Variable: + if self.sum is None: + raise ValueError("Metric is not initialized. Call `initialize` first.") + return self.sum + @property def value(self) -> tf.Tensor: """See base class.""" # TODO(cjfj): Assert summed type is floating-point? - return self.sum / tf.cast(self.count, dtype=self.sum.dtype) + return self._checked_sum / tf.cast( + self.count, dtype=self._checked_sum.dtype + ) def reset(self): - self.sum.assign(tf.zeros_like(self.sum)) + self._checked_sum.assign(tf.zeros_like(self._checked_sum)) self.count.assign(0)