Skip to content

Commit

Permalink
Resolves potential Nones caught by tytype.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675998642
  • Loading branch information
Sonnet Contributor authored and copybara-github committed Sep 18, 2024
1 parent 6d59725 commit 453d05c
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions sonnet/src/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def initialize(self, value: tf.Tensor):
def update(self, value: tf.Tensor):
"""See base class."""
self.initialize(value)
assert self.sum is not None
self.sum.assign_add(value)

@property
Expand All @@ -71,6 +72,8 @@ def value(self) -> tf.Tensor:

def reset(self):
"""See base class."""
if self.sum is None:
raise ValueError("Metric is not initialized. Call `initialize` first.")
self.sum.assign(tf.zeros_like(self.sum))


Expand All @@ -90,15 +93,24 @@ def initialize(self, value: tf.Tensor):
def update(self, value: tf.Tensor):
"""See base class."""
self.initialize(value)
assert self.sum is not None
self.sum.assign_add(value)
self.count.assign_add(1)

@property
def _checked_sum(self) -> tf.Variable:
if self.sum is None:
raise ValueError("Metric is not initialized. Call `initialize` first.")
return self.sum

@property
def value(self) -> tf.Tensor:
"""See base class."""
# TODO(cjfj): Assert summed type is floating-point?
return self.sum / tf.cast(self.count, dtype=self.sum.dtype)
return self._checked_sum / tf.cast(
self.count, dtype=self._checked_sum.dtype
)

def reset(self):
self.sum.assign(tf.zeros_like(self.sum))
self._checked_sum.assign(tf.zeros_like(self._checked_sum))
self.count.assign(0)

0 comments on commit 453d05c

Please sign in to comment.