From cdef4d43fb1a2c6c4349daa5080e4e8731c34569 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 19 Sep 2024 10:46:15 -0700 Subject: [PATCH] Use log1p(x) instead of log(1+x) (#1286) This function is more accurate than torch.log() for small values of input - https://pytorch.org/docs/stable/generated/torch.log1p.html Found with https://github.com/pytorch-labs/torchfix/ --- mnist_forward_forward/main.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mnist_forward_forward/main.py b/mnist_forward_forward/main.py index f137dee48a..a175126067 100644 --- a/mnist_forward_forward/main.py +++ b/mnist_forward_forward/main.py @@ -72,9 +72,8 @@ def train(self, x_pos, x_neg): for i in range(self.num_epochs): g_pos = self.forward(x_pos).pow(2).mean(1) g_neg = self.forward(x_neg).pow(2).mean(1) - loss = torch.log( - 1 - + torch.exp( + loss = torch.log1p( + torch.exp( torch.cat([-g_pos + self.threshold, g_neg - self.threshold]) ) ).mean()