From 5ef80c901f4d208f2ebe6c00f2148aea1fae7f3b Mon Sep 17 00:00:00 2001 From: Christopher Landschoot Date: Wed, 22 May 2024 16:24:41 -0500 Subject: [PATCH] Corrected RMS & metric calculations for digital silence. --- torch_log_wmse_audio_quality/metric.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch_log_wmse_audio_quality/metric.py b/torch_log_wmse_audio_quality/metric.py index be1c765..41074d0 100644 --- a/torch_log_wmse_audio_quality/metric.py +++ b/torch_log_wmse_audio_quality/metric.py @@ -62,10 +62,6 @@ def forward(self, unprocessed_audio: Tensor, processed_audio: Tensor, target_aud input_rms = calculate_rms(self.filters(unprocessed_audio.unsqueeze(1))) # unsqueeze to add "stem" dimension - # Avoid log(0) - if input_rms.sum() == 0: - return torch.log(torch.tensor(EPS)) * SCALER - # Calculate the logWMSE values = self._calculate_log_wmse( input_rms, @@ -98,6 +94,11 @@ def _calculate_log_wmse( Returns: Tensor: The logWMSE between the processed audio and target audio. """ + + # Add EPS if input_rms is 0 (silence) to avoid NaNs + if input_rms.sum() == 0: + input_rms = torch.ones_like(input_rms) * ERROR_TOLERANCE_THRESHOLD + # Calculate the scaling factor based on the input RMS scaling_factor = 1 / input_rms