From c41c5241572b89e2b3055ec23281c4d409f09f87 Mon Sep 17 00:00:00 2001 From: Markus Hofbauer Date: Mon, 28 Feb 2022 17:44:50 +0100 Subject: [PATCH] add wav example --- .github/workflows/ci.yml | 18 ++---------- .pre-commit-config.yaml | 12 ++++++-- Makefile | 7 +++++ README.md | 1 + docs/get-started.md | 9 ++++-- examples/evaluate_wav.py | 50 +++++++++++++++++++++++++++++++++ examples/white_noise.py | 21 ++++++++++++++ test/metrics/test_snr.py | 15 +++++++++- test/metrics/test_spqi.py | 24 +++++++++++++++- test/metrics/test_stsim.py | 52 +++++++++++++++++++++++++++++------ test/metrics/test_vibromaf.py | 3 +- test/signal/test_spectrum.py | 11 ++++---- test/signal/test_transform.py | 27 ++++++++++++++++++ vibromaf/metrics/snr.py | 2 ++ vibromaf/metrics/spqi.py | 9 ++++-- vibromaf/metrics/stsim.py | 23 ++++++++-------- vibromaf/signal/spectrum.py | 8 ++++-- vibromaf/signal/transform.py | 20 ++++++++++++-- vibromaf/util/matlab.py | 34 ++++++++++++++++++++++- 19 files changed, 288 insertions(+), 58 deletions(-) create mode 100644 examples/evaluate_wav.py create mode 100644 examples/white_noise.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0364da6..fc94917 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,25 +22,11 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements-dev.txt - - name: Check - run: pre-commit run --all-files - name: Test run: make test - name: Build Package run: make package - name: Check Package run: make check_dist - - # super-lint: - # runs-on: ubuntu-latest - # steps: - # - uses: actions/checkout@v2 - # - name: Lint Code Base - # uses: docker://github/super-linter:v4.8.5 - # env: - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # VALIDATE_BASH: false # Already configured in pre-commit - # VALIDATE_DOCKERFILE_HADOLINT: false - # VALIDATE_PYTHON_FLAKE8: false # Already configured in pre-commit - # VALIDATE_PYTHON_PYLINT: false # Already configured in pre-commit - # VALIDATE_PYTHON_ISORT: false # Already configured in pre-commit + - name: Smoke Test + run: make smoke_test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9ccf672..dfac2e9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,10 @@ default_stages: [commit] + +ci: + autoupdate_commit_msg: "chore(deps): pre-commit.ci autoupdate" + autoupdate_schedule: "monthly" + autofix_commit_msg: "style: pre-commit.ci fixes" + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 @@ -6,14 +12,14 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - repo: https://github.com/psf/black - rev: 21.12b0 + rev: 22.1.0 hooks: - id: black language_version: python3 - id: black-jupyter language_version: python3 - repo: https://github.com/asottile/blacken-docs - rev: v1.12.0 + rev: v1.12.1 hooks: - id: blacken-docs - repo: https://github.com/PyCQA/flake8 @@ -35,6 +41,6 @@ repos: - id: pyupgrade args: [--py37-plus] - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.8.0.3 + rev: v0.8.0.4 hooks: - id: shellcheck diff --git a/Makefile b/Makefile index 56a527c..c5cf925 100644 --- a/Makefile +++ b/Makefile @@ -31,3 +31,10 @@ deploy: package check_dist clean: rm -rf site htmlcov dist vibromaf.egg-info + +smoke_test: + mv dist/vibromaf-*.tar.gz dist/vibromaf.tar.gz || true + pip3 install -U dist/vibromaf.tar.gz + cd examples && python3 white_noise.py + +smoke_test_clean: clean package smoke_test diff --git a/README.md b/README.md index ee65903..304e7db 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # VibroMAF - Vibrotactile Multi-Method Assessment Fusion +[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/hofbi/vibromaf/main.svg)](https://results.pre-commit.ci/latest/github/hofbi/vibromaf/main) [![Actions Status](https://github.com/hofbi/vibromaf/workflows/CI/badge.svg)](https://github.com/hofbi/vibromaf) [![Actions Status](https://github.com/hofbi/vibromaf/workflows/Docs/badge.svg)](https://hofbi.github.io/vibromaf) [![Actions Status](https://github.com/hofbi/vibromaf/workflows/CodeQL/badge.svg)](https://github.com/hofbi/vibromaf) diff --git a/docs/get-started.md b/docs/get-started.md index 9a6eb6c..09b07d0 100644 --- a/docs/get-started.md +++ b/docs/get-started.md @@ -34,7 +34,7 @@ from vibromaf.metrics.stsim import st_sim st_sim_score = st_sim(sample_distorted_signal, sample_reference_signal) -print(st_sim_score) # Should be around 0.85 +print(st_sim_score) # Should be around 0.81 ``` Find further details how to use this metric at [ST-SIM](metrics/stsim.md). @@ -48,7 +48,12 @@ from vibromaf.metrics.spqi import spqi spqi_score = spqi(sample_distorted_signal, sample_reference_signal) -print(spqi_score) # Should be around 1.0 +print(spqi_score) # Should be around 0.43 ``` Find further details how to use this metric at [SPQI](metrics/spqi.md). + +## Full Example + +Find the full example in `examples/white_noise.py`. +All examples available can be found in the `examples` folder. diff --git a/examples/evaluate_wav.py b/examples/evaluate_wav.py new file mode 100644 index 0000000..847d0f4 --- /dev/null +++ b/examples/evaluate_wav.py @@ -0,0 +1,50 @@ +"""Example to evaluate WAV files""" + +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, FileType + +from scipy.io import wavfile + +from vibromaf.metrics.snr import snr +from vibromaf.metrics.spqi import spqi +from vibromaf.metrics.stsim import st_sim + + +def parse_arguments(): + """Parse command line arguments""" + parser = ArgumentParser( + description=__doc__, + formatter_class=ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "distorted", + type=FileType("r"), + help="Distorted .wav file", + ) + parser.add_argument( + "reference", + type=FileType("r"), + help="Undistorted reference .wav file", + ) + return parser.parse_args() + + +def main(): + """main""" + args = parse_arguments() + + distorted_signal = wavfile.read(args.distorted.name)[1] + reference_signal = wavfile.read(args.reference.name)[1] + + # Calculate metric scores + snr_score = snr(distorted_signal, reference_signal) + st_sim_score = st_sim(distorted_signal, reference_signal) + spqi_score = spqi(distorted_signal, reference_signal) + + # Print individual metric scores + print(f"SNR score: {snr_score}") + print(f"ST-SIM score: {st_sim_score}") + print(f"SPQI score: {spqi_score}") + + +if __name__ == "__main__": + main() diff --git a/examples/white_noise.py b/examples/white_noise.py new file mode 100644 index 0000000..768849b --- /dev/null +++ b/examples/white_noise.py @@ -0,0 +1,21 @@ +"""Simple get started example with white noise signals""" + +import numpy as np + +from vibromaf.metrics.snr import snr +from vibromaf.metrics.spqi import spqi +from vibromaf.metrics.stsim import st_sim + +# Define sample signals +sample_reference_signal = np.ones(1000) * 1000 + np.random.randn(1000) +sample_distorted_signal = sample_reference_signal + np.random.randn(1000) + +# Calculate metric scores +snr_score = snr(sample_distorted_signal, sample_reference_signal) +st_sim_score = st_sim(sample_distorted_signal, sample_reference_signal) +spqi_score = spqi(sample_distorted_signal, sample_reference_signal) + +# Print individual metric scores +print(f"SNR score: {snr_score}") +print(f"ST-SIM score: {st_sim_score}") +print(f"SPQI score: {spqi_score}") diff --git a/test/metrics/test_snr.py b/test/metrics/test_snr.py index 2772be8..292b1cc 100644 --- a/test/metrics/test_snr.py +++ b/test/metrics/test_snr.py @@ -33,7 +33,20 @@ def test_snr__sample_signals(self): result = snr(sample_distorted_signal, sample_reference_signal) - self.assertAlmostEqual(60, result, delta=1.0) + self.assertAlmostEqual(60, result, delta=2.0) + + def test_snr__dist_larger_than_ref__dist_should_be_truncated_and_warn(self): + signal = np.array([0, 1]) + dist = np.array([0, 1, 2]) + with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"): + result = snr(dist, signal) + self.assertEqual(np.Inf, result) + + def test_snr__dist_shorter_than_ref__should_throw(self): + signal = np.array([0, 1, 2]) + dist = np.array([0, 1]) + with self.assertRaisesRegex(ValueError, r"Distorted .* must not be shorter"): + snr(dist, signal) def test_nsnr__snr_larger_than_max__should_be_1(self): signal = np.array([0, 1]) diff --git a/test/metrics/test_spqi.py b/test/metrics/test_spqi.py index f1e246d..174258d 100644 --- a/test/metrics/test_spqi.py +++ b/test/metrics/test_spqi.py @@ -21,4 +21,26 @@ def test_spqi_wrapper__sample_signals(self): result = spqi(sample_distorted_signal, sample_reference_signal) - self.assertAlmostEqual(1.0, result) + self.assertGreaterEqual(1, result) + self.assertGreaterEqual(result, 0) + + def test_spqi__truncated_signals_identical__dist_should_be_truncated(self): + signal = np.linspace(0, 1, 2800) + dist = np.append(np.linspace(0, 1, 2800), np.ones(300)) + with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"): + result = spqi(dist, signal) + self.assertEqual(1, result) + + def test_spqi__dist_larger_than_ref__dist_should_be_truncated(self): + signal = np.linspace(0, 1, 2800) + dist = np.append(np.ones(300), np.linspace(0, 1, 2800)) + with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"): + result = spqi(dist, signal) + self.assertGreaterEqual(1, result) + self.assertGreaterEqual(result, 0) + + def test_spqi__dist_shorter_than_ref__should_throw(self): + signal = np.append(np.linspace(0, 1, 2800), np.ones(300)) + dist = np.linspace(0, 1, 2800) + with self.assertRaisesRegex(ValueError, r"Distorted .* must not be shorter"): + spqi(dist, signal) diff --git a/test/metrics/test_stsim.py b/test/metrics/test_stsim.py index 6c3803a..4d6b382 100644 --- a/test/metrics/test_stsim.py +++ b/test/metrics/test_stsim.py @@ -21,15 +21,51 @@ def test_st_sim_wrapper__sample_signals(self): result = st_sim(sample_distorted_signal, sample_reference_signal) - self.assertAlmostEqual(0.85, result, delta=0.1) + self.assertGreaterEqual(1, result) + self.assertGreaterEqual(result, 0) - def test_compute_block_sim__one_block_zero__zero(self): - ref_block = np.ones(4) - dist_block = np.zeros(4) - result = STSIM.compute_block_sim(ref_block, dist_block) + def test_st_sim__truncated_signals_identical__dist_should_be_truncated(self): + signal = np.linspace(0, 1, 2800) + dist = np.append(np.linspace(0, 1, 2800), np.ones(300)) + with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"): + result = st_sim(dist, signal) + self.assertEqual(1, result) + + def test_st_sim__dist_larger_than_ref__dist_should_be_truncated(self): + signal = np.linspace(0, 1, 2800) + dist = np.append(np.ones(300), np.linspace(0, 1, 2800)) + with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"): + result = st_sim(dist, signal) + self.assertGreaterEqual(1, result) + self.assertGreaterEqual(result, 0) + + def test_st_sim__dist_shorter_than_ref__should_throw(self): + signal = np.append(np.linspace(0, 1, 2800), np.ones(300)) + dist = np.linspace(0, 1, 2800) + with self.assertRaisesRegex(ValueError, r"Distorted .* must not be shorter"): + st_sim(dist, signal) + + def test_compute_sim__one_block_zero__zero(self): + ref_block = np.ones((4, 2)) + dist_block = np.zeros((4, 2)) + result = STSIM.compute_sim(ref_block, dist_block) self.assertEqual(0, result) - def test_compute_block_sim__blocks_identical__one(self): - block = np.array([1, 2, 3, 4]) - result = STSIM.compute_block_sim(block, block) + def test_compute_sim__blocks_identical__one(self): + block = np.ones((4, 3)) + result = STSIM.compute_sim(block, block) self.assertEqual(1, result) + + def test_compute_sim__values_larger_than_one_possible(self): + ref_block = np.array([[0.5] * 4]) + dist_block = np.array([[0.5, 0.5, 0.5, 1]]) + result = STSIM.compute_sim(ref_block, dist_block) + self.assertEqual(1.25, result) + + def test_st_sim_init__eta_grater_one__should_throw(self): + with self.assertRaisesRegex(ValueError, "Eta must be between 0 and 1."): + STSIM(eta=1.1) + + def test_st_sim_init__eta_negative__should_throw(self): + with self.assertRaisesRegex(ValueError, "Eta must be between 0 and 1."): + STSIM(eta=-0.1) diff --git a/test/metrics/test_vibromaf.py b/test/metrics/test_vibromaf.py index 1685763..ab7f7e8 100644 --- a/test/metrics/test_vibromaf.py +++ b/test/metrics/test_vibromaf.py @@ -27,4 +27,5 @@ def test_vibromaf_wrapper__sample_signals__some_output_value(self): sample_distorted_signal, sample_reference_signal, Path("test-model.pickle") ) - self.assertAlmostEqual(0.16, result, delta=0.1) + self.assertGreaterEqual(1, result) + self.assertGreaterEqual(result, 0) diff --git a/test/signal/test_spectrum.py b/test/signal/test_spectrum.py index f443529..ca0d239 100644 --- a/test/signal/test_spectrum.py +++ b/test/signal/test_spectrum.py @@ -34,17 +34,18 @@ def test_mag2db(self): def test_compute_normalized_spectral_difference__same_signals_should_be_minus_inf( self, ): - signal = np.ones(10) + signal = np.ones((10, 2)) result = compute_normalized_spectral_difference(signal, signal) - self.assertEqual(-np.inf, result) + self.assertListEqual([-np.inf] * 10, list(result)) def test_compute_normalized_spectral_difference__different_signals_should_be_positive( self, ): - signal_one = np.ones(10) - signal_two = np.zeros(10) + signal_one = np.ones((2, 10)) + signal_two = np.zeros((2, 10)) result = compute_normalized_spectral_difference(signal_one, signal_two) - self.assertAlmostEqual(1.542, result, delta=0.001) + self.assertGreaterEqual(0, result[0]) + self.assertEqual(2, result.size) def test_compute_spectral_support__zeros_array__array_with_0p5(self): spectrum = np.zeros((2, 4)) diff --git a/test/signal/test_transform.py b/test/signal/test_transform.py index 91bac0b..6d5acd4 100644 --- a/test/signal/test_transform.py +++ b/test/signal/test_transform.py @@ -11,6 +11,7 @@ compute_block_dct, compute_block_dft, cut_off_strategy, + preprocess_input_signal, zero_padding_strategy, ) @@ -87,6 +88,24 @@ def test_zero_padding_strategy__signal_equal_to_block_length__unchanged( result = zero_padding_strategy(input_signal, 3) self.assertListEqual([1, 2, 3], list(result)) + def test_preprocess_input_signal__dist_larger_than_ref__dist_should_be_truncated_and_warn( + self, + ): + signal = np.array([0, 1]) + dist = np.array([0, 1, 2]) + with self.assertWarnsRegex( + RuntimeWarning, + r"Truncating distorted signal .* since longer than reference", + ): + result = preprocess_input_signal(dist, signal) + self.assertListEqual(list(signal), list(result)) + + def test_preprocess_input_signal__dist_shorter_than_ref__should_throw(self): + signal = np.array([0, 1, 2]) + dist = np.array([0, 1]) + with self.assertRaisesRegex(ValueError, r"Distorted .* must not be shorter"): + preprocess_input_signal(dist, signal) + class BlockBuilderTest(unittest.TestCase): """Block Builder Test""" @@ -131,6 +150,14 @@ def test_divide_and_normalize__periodic_array__array_divided_by_two_and_shifted( result = unit.divide_and_normalize(input_signal) self.assertListEqual([-1, 1, -1, 1], list(result[0])) + def test_divide_and_normalize__multiple_blocks__correct_reshaped(self): + input_signal = np.array([-2, 2, -2, 2, -2, 2, -2, 2]) + unit = BlockBuilder(2) + result = unit.divide_and_normalize(input_signal) + self.assertTrue( + np.array_equal(np.array([[-1, 1], [-1, 1], [-1, 1], [-1, 1]]), result) + ) + class PerceptualSpectrumBuilderTest(unittest.TestCase): """Perceptual Spectrum Builder Test""" diff --git a/vibromaf/metrics/snr.py b/vibromaf/metrics/snr.py index ecad264..fad190b 100644 --- a/vibromaf/metrics/snr.py +++ b/vibromaf/metrics/snr.py @@ -3,6 +3,7 @@ import numpy as np from vibromaf.signal.spectrum import pow2db, signal_energy +from vibromaf.signal.transform import preprocess_input_signal def snr(distorted: np.array, reference: np.array) -> float: @@ -17,6 +18,7 @@ def snr(distorted: np.array, reference: np.array) -> float: ------- * `float` The SNR. """ + distorted = preprocess_input_signal(distorted, reference) return pow2db(signal_energy(reference) / signal_energy(distorted - reference)) diff --git a/vibromaf/metrics/spqi.py b/vibromaf/metrics/spqi.py index 1206a47..5daa0c4 100644 --- a/vibromaf/metrics/spqi.py +++ b/vibromaf/metrics/spqi.py @@ -5,7 +5,7 @@ import numpy as np from vibromaf.signal.spectrum import compute_normalized_spectral_difference -from vibromaf.signal.transform import PerceptualSpectrumBuilder +from vibromaf.signal.transform import PerceptualSpectrumBuilder, preprocess_input_signal def spqi( @@ -37,6 +37,7 @@ class SPQI: perceptual_spectrum_builder = PerceptualSpectrumBuilder() def calculate(self, distorted: np.array, reference: np.array) -> float: + distorted = preprocess_input_signal(distorted, reference) if np.array_equal(distorted, reference): return 1 @@ -53,9 +54,11 @@ def calculate(self, distorted: np.array, reference: np.array) -> float: block_spqi_scores = self.__compute_block_spqi(norm_perceptual_difference) - return np.mean(block_spqi_scores) + return float(np.mean(block_spqi_scores)) - def __compute_block_spqi(self, normalized_perceptual_difference: np.array) -> float: + def __compute_block_spqi( + self, normalized_perceptual_difference: np.array + ) -> np.array: return ( 1 - np.tanh(self.eta * normalized_perceptual_difference - self.threshold) ) / 2 diff --git a/vibromaf/metrics/stsim.py b/vibromaf/metrics/stsim.py index 1a04917..f6efe6d 100644 --- a/vibromaf/metrics/stsim.py +++ b/vibromaf/metrics/stsim.py @@ -5,7 +5,7 @@ import numpy as np from vibromaf.signal.spectrum import compute_spectral_support -from vibromaf.signal.transform import PerceptualSpectrumBuilder +from vibromaf.signal.transform import PerceptualSpectrumBuilder, preprocess_input_signal def st_sim(distorted: np.array, reference: np.array, eta: float = 2 / 3) -> float: @@ -33,6 +33,7 @@ class STSIM: perceptual_spectrum_builder = PerceptualSpectrumBuilder() def calculate(self, distorted: np.array, reference: np.array) -> float: + distorted = preprocess_input_signal(distorted, reference) if np.array_equal(distorted, reference): return 1 @@ -58,17 +59,15 @@ def calculate(self, distorted: np.array, reference: np.array) -> float: return pow(temporal_sim, self.eta) * pow(spectral_sim, 1 - self.eta) - @staticmethod - def compute_block_sim( - reference_block: np.array, distorted_block: np.array - ) -> float: - return np.sum(reference_block * distorted_block) / np.sum( - np.power(reference_block, 2) - ) - @staticmethod def compute_sim(reference: np.array, distorted: np.array) -> float: - block_sim = np.apply_along_axis( - STSIM.compute_block_sim, 1, reference, distorted + return float( + np.mean( + np.sum(reference * distorted, axis=1) + / np.sum(np.power(reference, 2), axis=1) + ) ) - return np.mean(block_sim) + + def __post_init__(self): + if not 0.0 < self.eta < 1.0: + raise ValueError("Eta must be between 0 and 1.") diff --git a/vibromaf/signal/spectrum.py b/vibromaf/signal/spectrum.py index 8401289..d38abe4 100644 --- a/vibromaf/signal/spectrum.py +++ b/vibromaf/signal/spectrum.py @@ -36,9 +36,11 @@ def signal_energy(signal: np.array) -> np.array: def compute_normalized_spectral_difference( reference_spectrum: np.array, distorted_spectrum: np.array ) -> np.array: - """Compute the normalized difference of two spectrums""" - difference = np.sum(np.abs(db2pow(reference_spectrum) - db2pow(distorted_spectrum))) - return pow2db(difference / np.sum(np.abs(db2pow(difference)))) + """Compute the normalized difference of two spectra""" + difference = np.sum( + np.abs(db2pow(reference_spectrum) - db2pow(distorted_spectrum)), axis=1 + ) + return pow2db(difference / np.sum(np.abs(db2pow(reference_spectrum)), axis=1)) def compute_spectral_support(spectrum: np.array, scale: float = 12) -> np.array: diff --git a/vibromaf/signal/transform.py b/vibromaf/signal/transform.py index b83818d..2659b40 100644 --- a/vibromaf/signal/transform.py +++ b/vibromaf/signal/transform.py @@ -1,6 +1,7 @@ """Transform module""" import math +import warnings from dataclasses import dataclass from typing import Callable @@ -11,6 +12,21 @@ from vibromaf.signal.spectrum import mag2db +def preprocess_input_signal(distorted: np.array, reference: np.array) -> np.array: + """Verify input signal lengths and prepare distorted signal for the metrics""" + if distorted.size > reference.size: + warnings.warn( + f"Truncating distorted signal {distorted.shape} since longer than reference signal {reference.shape}.", + RuntimeWarning, + ) + return np.resize(distorted, reference.shape) + elif distorted.size < reference.size: + raise ValueError( + f"Distorted signal {distorted.shape} must not be shorter than reference signal {reference.shape}!" + ) + return distorted + + def compute_block_dft(block: np.array) -> np.array: """Compute DFT spectrum using FFT""" block_length = np.size(block) @@ -58,8 +74,8 @@ def divide(self, signal: np.array) -> np.array: def divide_and_normalize(self, signal: np.array) -> np.array: blocks = self.divide(signal) - means = np.apply_along_axis(np.mean, 1, blocks) - stds = np.apply_along_axis(np.std, 1, blocks) + means = np.apply_along_axis(np.mean, 1, blocks).reshape((blocks.shape[0], 1)) + stds = np.apply_along_axis(np.std, 1, blocks).reshape((blocks.shape[0], 1)) return (blocks - means) / stds diff --git a/vibromaf/util/matlab.py b/vibromaf/util/matlab.py index 4e4ddda..4f11b7a 100644 --- a/vibromaf/util/matlab.py +++ b/vibromaf/util/matlab.py @@ -63,7 +63,39 @@ def reshape_per_compression_rate( data: np.array, number_of_compression_levels: int = 17 ) -> np.array: """ - Reshape the data into same compression level per row: + Reshape the data into same compression level per row """ number_of_columns = int(data.size / number_of_compression_levels) return data.reshape((number_of_compression_levels, number_of_columns)) + + +class MatSignalLoader: + """Helper class to load test signals from mat files""" + + def __init__(self, metric: str, codec: str = "VCPWQ"): + self.__reference = load_signal_from_mat( + config.DATA_PATH / "Signals.mat", "Signals" + ) + self.__distorted = load_signal_from_mat( + config.DATA_PATH / f"recsig_{codec}.mat", f"recsig_{codec}" + ) + self.__metric_scores = load_signal_from_mat( + config.DATA_PATH / f"{metric}_{codec}.mat", f"{metric}_{codec}" + ) + + def signal_ids(self): + return range(self.__reference.shape[1]) + + def compression_levels(self): + return range(self.__distorted.shape[0]) + + def load_reference_signal(self, signal_id: int): + return self.__reference[:, signal_id] + + def load_distorted_signal(self, signal_id: int, compression_level: int): + return self.__distorted[compression_level, signal_id].reshape( + -1, + ) + + def load_quality_score(self, signal_id: int, compression_level: int): + return self.__metric_scores[compression_level, signal_id]