
Commit

Added tests for HumanHearingSensitivityFilter and confirm correct circular shift
crlandsc committed May 20, 2024
1 parent 39c48b9 commit 944c9cf
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion tests/test_metric.py
@@ -3,8 +3,10 @@
sys.path.append("/Users/chris/Desktop/Whitebalance/torch-log-wmse-audio-quality")
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch_log_wmse_audio_quality import LogWMSE
from torch_log_wmse_audio_quality.utils import calculate_rms
from torch_log_wmse_audio_quality.freq_weighting_filter import prepare_impulse_response_fft, HumanHearingSensitivityFilter

class TestLogWMSELoss(unittest.TestCase):
def setUp(self):
@@ -68,7 +70,7 @@ def test_forward(self):

def test_logWMSE_metric_comparison(self):
"""For comparison with the original logWMSE metric implementation in numpy."""
audio_lengths = [0.01, 0.1, 0.5, 1.0, 2.0, 10.0, 50.0] # Different audio lengths
audio_lengths = [0.01, 0.1, 0.5, 1.0, 2.0, 10.0] # Different audio lengths
for i, audio_length in enumerate(audio_lengths):
log_wmse_loss = LogWMSE(audio_length=audio_length, sample_rate=44100)
for j in range(3):
@@ -89,5 +91,46 @@ def test_logWMSE_metric_comparison(self):

print(f"Test {i}, Subtest {j}, Audio Length: {audio_length}, Loss: {loss}, Seed: {(i+1)*(j+1)}")

class TestFreqWeightingFilter(unittest.TestCase):
def setUp(self):
# Example audio data, replace with actual audio loading if needed
self.plot_output = False
self.sample_rate = 44100
self.audio_length = 1.7516936
t = np.arange(0, int(self.audio_length*self.sample_rate)) / self.sample_rate
self.audio = torch.tensor(0.5 * np.sin(2 * np.pi * 440 * t)) # A simple 440 Hz sine wave
self.audio = self.audio[None, None, None, :]

def test_prepare_impulse_response_fft(self):
ir = torch.rand(512) # Example impulse response
fft_size = 1024
ir_fft = prepare_impulse_response_fft(ir, fft_size)
self.assertEqual(ir_fft.shape[-1], fft_size//2+1)

def test_HumanHearingSensitivityFilter(self):
plot_upper_bound = 500
hhs_filter = HumanHearingSensitivityFilter(audio_length=self.audio_length, sample_rate=self.sample_rate)
# Zero out samples 50-100 and add ±0.5 steps at 101-125 / 126-150 to make the time alignment visible
self.audio[:, :, :, 50:100] = 0
self.audio[:, :, :, 101:125] = 0.5
self.audio[:, :, :, 126:150] = -0.5

filtered_audio = hhs_filter(self.audio)

# Plot the first plot_upper_bound samples before and after filtering
if self.plot_output:
fig, axs = plt.subplots(2, 1, figsize=(12, 8))
axs[0].plot(self.audio.squeeze()[:plot_upper_bound])
axs[0].set_title(f'Original Audio (First {plot_upper_bound} Samples)')
axs[0].set_ylim(-1, 1)
axs[1].plot(filtered_audio.squeeze()[:plot_upper_bound])
axs[1].set_title(f'Filtered Audio (First {plot_upper_bound} Samples)')
axs[1].set_ylim(-1, 1)
plt.tight_layout()
plt.show()

self.assertEqual(filtered_audio.shape, self.audio.shape)


if __name__ == "__main__":
unittest.main()
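
For context on the fft_size // 2 + 1 assertion in test_prepare_impulse_response_fft: a one-sided FFT (torch.fft.rfft) of a real, length-N signal has N // 2 + 1 non-redundant bins. The sketch below only illustrates that shape relationship, assuming the impulse response is zero-padded out to fft_size; it is not the library's prepare_impulse_response_fft implementation.

import torch

fft_size = 1024
ir = torch.rand(512)  # example impulse response, as in the test

# Zero-pad the IR to fft_size and take the one-sided FFT of the real signal;
# rfft of a length-N real input yields N // 2 + 1 frequency bins.
padded = torch.nn.functional.pad(ir, (0, fft_size - ir.shape[-1]))
ir_fft = torch.fft.rfft(padded)
assert ir_fft.shape[-1] == fft_size // 2 + 1  # same shape the unit test expects

Multiplying this spectrum with the rfft of an audio block and inverting with irfft is the usual fast-convolution route, which is presumably where the circular-shift concern in the commit title comes from.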
