From 5ad4abf1e98e0b43eb53132dcc678d07b1b7e949 Mon Sep 17 00:00:00 2001 From: Christopher Landschoot Date: Wed, 22 May 2024 18:27:15 -0500 Subject: [PATCH] Changed the scaling-factor guard so the input RMS no longer needs to equal 0; it is now replaced with the minimum value whenever it is lower than the error threshold. --- tests/test_metric.py | 2 +- torch_log_wmse_audio_quality/metric.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_metric.py b/tests/test_metric.py index 6d4b036..5f769c3 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -89,7 +89,7 @@ def test_forward_silence(self): # Generate random inputs (scale between -1 and 1) audio_lengths_samples = int(audio_length * sample_rate) - unprocessed_audio = torch.zeros(batch, audio_channels, audio_lengths_samples) + unprocessed_audio = torch.rand(batch, audio_channels, audio_lengths_samples) * convert_decibels_to_amplitude_ratio(-75) processed_audio = torch.rand(batch, audio_stems, audio_channels, audio_lengths_samples) * convert_decibels_to_amplitude_ratio(-60) target_audio = torch.zeros(batch, audio_stems, audio_channels, audio_lengths_samples) diff --git a/torch_log_wmse_audio_quality/metric.py b/torch_log_wmse_audio_quality/metric.py index 41074d0..bcbc665 100644 --- a/torch_log_wmse_audio_quality/metric.py +++ b/torch_log_wmse_audio_quality/metric.py @@ -95,8 +95,8 @@ def _calculate_log_wmse( Tensor: The logWMSE between the processed audio and target audio. """ - # Add EPS if input_rms is 0 (silence) to avoid NaNs - if input_rms.sum() == 0: + # Add EPS if input_rms is 0 (silence), or close to it, to avoid NaNs + if input_rms.sum() < ERROR_TOLERANCE_THRESHOLD: input_rms = torch.ones_like(input_rms) * ERROR_TOLERANCE_THRESHOLD # Calculate the scaling factor based on the input RMS