diff --git a/tests/test_metric.py b/tests/test_metric.py
index 714b789..54408cd 100644
--- a/tests/test_metric.py
+++ b/tests/test_metric.py
@@ -3,8 +3,10 @@
 sys.path.append("/Users/chris/Desktop/Whitebalance/torch-log-wmse-audio-quality")
 import torch
 import numpy as np
+import matplotlib.pyplot as plt
 from torch_log_wmse_audio_quality import LogWMSE
 from torch_log_wmse_audio_quality.utils import calculate_rms
+from torch_log_wmse_audio_quality.freq_weighting_filter import prepare_impulse_response_fft, HumanHearingSensitivityFilter
 
 class TestLogWMSELoss(unittest.TestCase):
     def setUp(self):
@@ -68,7 +70,7 @@ def test_forward(self):
 
     def test_logWMSE_metric_comparison(self):
         """For comparison with the original logWMSE metric implementation in numpy."""
-        audio_lengths = [0.01, 0.1, 0.5, 1.0, 2.0, 10.0, 50.0]  # Different audio lengths
+        audio_lengths = [0.01, 0.1, 0.5, 1.0, 2.0, 10.0]  # Different audio lengths
         for i, audio_length in enumerate(audio_lengths):
             log_wmse_loss = LogWMSE(audio_length=audio_length, sample_rate=44100)
             for j in range(3):
@@ -89,5 +91,46 @@ def test_logWMSE_metric_comparison(self):
                     print(f"Test {i}, Subtest {j}, Audio Length: {audio_length}, Loss: {loss}, Seed: {(i+1)*(j+1)}")
 
 
+class TestFreqWeightingFilter(unittest.TestCase):
+    def setUp(self):
+        # Example audio data; replace with actual audio loading if needed
+        self.plot_output = False
+        self.sample_rate = 44100
+        self.audio_length = 1.7516936
+        t = np.arange(0, int(self.audio_length * self.sample_rate)) / self.sample_rate
+        self.audio = torch.tensor(0.5 * np.sin(2 * np.pi * 440 * t))  # A simple 440 Hz sine wave
+        self.audio = self.audio[None, None, None, :]
+
+    def test_prepare_impulse_response_fft(self):
+        ir = torch.rand(512)  # Example impulse response
+        fft_size = 1024
+        ir_fft = prepare_impulse_response_fft(ir, fft_size)
+        self.assertEqual(ir_fft.shape[-1], fft_size // 2 + 1)
+
+    def test_HumanHearingSensitivityFilter(self):
+        plot_upper_bound = 500
+        hhs_filter = HumanHearingSensitivityFilter(audio_length=self.audio_length, sample_rate=self.sample_rate)
+        # Zero out samples 50-100 and add step transitions to demonstrate time alignment
+        self.audio[:, :, :, 50:100] = 0
+        self.audio[:, :, :, 101:125] = 0.5
+        self.audio[:, :, :, 126:150] = -0.5
+
+        filtered_audio = hhs_filter(self.audio)
+
+        # Plot the first plot_upper_bound samples before and after filtering
+        if self.plot_output:
+            fig, axs = plt.subplots(2, 1, figsize=(12, 8))
+            axs[0].plot(self.audio.squeeze()[:plot_upper_bound])
+            axs[0].set_title(f'Original Audio (First {plot_upper_bound} Samples)')
+            axs[0].set_ylim(-1, 1)
+            axs[1].plot(filtered_audio.squeeze()[:plot_upper_bound])
+            axs[1].set_title(f'Filtered Audio (First {plot_upper_bound} Samples)')
+            axs[1].set_ylim(-1, 1)
+            plt.tight_layout()
+            plt.show()
+
+        self.assertEqual(filtered_audio.shape, self.audio.shape)
+
+
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file
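
Side note on the assertion in test_prepare_impulse_response_fft: a real-input FFT of length fft_size always yields fft_size // 2 + 1 one-sided frequency bins. The sketch below illustrates that property, assuming the library zero-pads the impulse response to the FFT size before transforming; the function name and padding step here are illustrative, not the package's actual implementation.

import torch

def prepare_ir_fft_sketch(ir: torch.Tensor, fft_size: int) -> torch.Tensor:
    # Illustrative only: zero-pad the impulse response to fft_size,
    # then take the one-sided FFT of the real signal.
    padded = torch.nn.functional.pad(ir, (0, fft_size - ir.shape[-1]))
    return torch.fft.rfft(padded, n=fft_size)

ir_fft = prepare_ir_fft_sketch(torch.rand(512), 1024)
print(ir_fft.shape[-1])  # 513 == 1024 // 2 + 1, matching the test's expectation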
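
Similarly, the shape assertion in test_HumanHearingSensitivityFilter relies on the filter returning output the same length as its input. A minimal sketch of one generic way to achieve that with FFT convolution (compute the full linear convolution, then trim back to the input length); this is not the package's actual HumanHearingSensitivityFilter internals, and the dimension names in the comment are a guess based on the test's tensor layout.

import torch

def fft_filter_sketch(audio: torch.Tensor, ir: torch.Tensor) -> torch.Tensor:
    # Full linear convolution length, then trim the output back so its
    # shape matches the input shape (as the test asserts).
    n = audio.shape[-1] + ir.shape[-1] - 1
    spectrum = torch.fft.rfft(audio, n=n) * torch.fft.rfft(ir, n=n)
    return torch.fft.irfft(spectrum, n=n)[..., : audio.shape[-1]]

audio = torch.rand(1, 1, 1, 44100)  # (batch, source, channel, time), as in the test above
assert fft_filter_sketch(audio, torch.rand(512)).shape == audio.shape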