Skip to content

Commit

Permalink
Corrected RMS & metric calculations for digital silence.
Browse files Browse the repository at this point in the history
  • Loading branch information
crlandsc committed May 22, 2024
1 parent 607c775 commit 5ef80c9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torch_log_wmse_audio_quality/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def forward(self, unprocessed_audio: Tensor, processed_audio: Tensor, target_aud

input_rms = calculate_rms(self.filters(unprocessed_audio.unsqueeze(1))) # unsqueeze to add "stem" dimension

# Avoid log(0)
if input_rms.sum() == 0:
return torch.log(torch.tensor(EPS)) * SCALER

# Calculate the logWMSE
values = self._calculate_log_wmse(
input_rms,
Expand Down Expand Up @@ -98,6 +94,11 @@ def _calculate_log_wmse(
Returns:
Tensor: The logWMSE between the processed audio and target audio.
"""

# If input_rms sums to zero (digital silence), replace it with ERROR_TOLERANCE_THRESHOLD so the 1 / input_rms scaling below does not divide by zero
if input_rms.sum() == 0:
input_rms = torch.ones_like(input_rms) * ERROR_TOLERANCE_THRESHOLD

# Calculate the scaling factor based on the input RMS
scaling_factor = 1 / input_rms

Expand Down

0 comments on commit 5ef80c9

Please sign in to comment.