Skip to content

Commit

Permalink
add wav example
Browse files Browse the repository at this point in the history
  • Loading branch information
hofbi committed Feb 28, 2022
1 parent 960e2e1 commit c41c524
Show file tree
Hide file tree
Showing 19 changed files with 288 additions and 58 deletions.
18 changes: 2 additions & 16 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,11 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
- name: Check
run: pre-commit run --all-files
- name: Test
run: make test
- name: Build Package
run: make package
- name: Check Package
run: make check_dist

# super-lint:
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v2
# - name: Lint Code Base
# uses: docker://github/super-linter:v4.8.5
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# VALIDATE_BASH: false # Already configured in pre-commit
# VALIDATE_DOCKERFILE_HADOLINT: false
# VALIDATE_PYTHON_FLAKE8: false # Already configured in pre-commit
# VALIDATE_PYTHON_PYLINT: false # Already configured in pre-commit
# VALIDATE_PYTHON_ISORT: false # Already configured in pre-commit
- name: Smoke Test
run: make smoke_test
12 changes: 9 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
default_stages: [commit]

ci:
autoupdate_commit_msg: "chore(deps): pre-commit.ci autoupdate"
autoupdate_schedule: "monthly"
autofix_commit_msg: "style: pre-commit.ci fixes"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- repo: https://github.com/psf/black
rev: 21.12b0
rev: 22.1.0
hooks:
- id: black
language_version: python3
- id: black-jupyter
language_version: python3
- repo: https://github.com/asottile/blacken-docs
rev: v1.12.0
rev: v1.12.1
hooks:
- id: blacken-docs
- repo: https://github.com/PyCQA/flake8
Expand All @@ -35,6 +41,6 @@ repos:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.8.0.3
rev: v0.8.0.4
hooks:
- id: shellcheck
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,10 @@ deploy: package check_dist

clean:
rm -rf site htmlcov dist vibromaf.egg-info

smoke_test:
mv dist/vibromaf-*.tar.gz dist/vibromaf.tar.gz || true
pip3 install -U dist/vibromaf.tar.gz
cd examples && python3 white_noise.py

smoke_test_clean: clean package smoke_test
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# VibroMAF - Vibrotactile Multi-Method Assessment Fusion

[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/hofbi/vibromaf/main.svg)](https://results.pre-commit.ci/latest/github/hofbi/vibromaf/main)
[![Actions Status](https://github.com/hofbi/vibromaf/workflows/CI/badge.svg)](https://github.com/hofbi/vibromaf)
[![Actions Status](https://github.com/hofbi/vibromaf/workflows/Docs/badge.svg)](https://hofbi.github.io/vibromaf)
[![Actions Status](https://github.com/hofbi/vibromaf/workflows/CodeQL/badge.svg)](https://github.com/hofbi/vibromaf)
Expand Down
9 changes: 7 additions & 2 deletions docs/get-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ from vibromaf.metrics.stsim import st_sim

st_sim_score = st_sim(sample_distorted_signal, sample_reference_signal)

print(st_sim_score) # Should be around 0.85
print(st_sim_score) # Should be around 0.81
```

Find further details how to use this metric at [ST-SIM](metrics/stsim.md).
Expand All @@ -48,7 +48,12 @@ from vibromaf.metrics.spqi import spqi

spqi_score = spqi(sample_distorted_signal, sample_reference_signal)

print(spqi_score) # Should be around 1.0
print(spqi_score) # Should be around 0.43
```

Find further details how to use this metric at [SPQI](metrics/spqi.md).

## Full Example

Find the full example in `examples/white_noise.py`.
All examples available can be found in the `examples` folder.
50 changes: 50 additions & 0 deletions examples/evaluate_wav.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Example to evaluate WAV files"""

from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, FileType

from scipy.io import wavfile

from vibromaf.metrics.snr import snr
from vibromaf.metrics.spqi import spqi
from vibromaf.metrics.stsim import st_sim


def parse_arguments():
"""Parse command line arguments"""
parser = ArgumentParser(
description=__doc__,
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"distorted",
type=FileType("r"),
help="Distorted .wav file",
)
parser.add_argument(
"reference",
type=FileType("r"),
help="Undistorted reference .wav file",
)
return parser.parse_args()


def main():
"""main"""
args = parse_arguments()

distorted_signal = wavfile.read(args.distorted.name)[1]
reference_signal = wavfile.read(args.reference.name)[1]

# Calculate metric scores
snr_score = snr(distorted_signal, reference_signal)
st_sim_score = st_sim(distorted_signal, reference_signal)
spqi_score = spqi(distorted_signal, reference_signal)

# Print individual metric scores
print(f"SNR score: {snr_score}")
print(f"ST-SIM score: {st_sim_score}")
print(f"SPQI score: {spqi_score}")


if __name__ == "__main__":
main()
21 changes: 21 additions & 0 deletions examples/white_noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Simple get started example with white noise signals"""

import numpy as np

from vibromaf.metrics.snr import snr
from vibromaf.metrics.spqi import spqi
from vibromaf.metrics.stsim import st_sim

# Define sample signals
sample_reference_signal = np.ones(1000) * 1000 + np.random.randn(1000)
sample_distorted_signal = sample_reference_signal + np.random.randn(1000)

# Calculate metric scores
snr_score = snr(sample_distorted_signal, sample_reference_signal)
st_sim_score = st_sim(sample_distorted_signal, sample_reference_signal)
spqi_score = spqi(sample_distorted_signal, sample_reference_signal)

# Print individual metric scores
print(f"SNR score: {snr_score}")
print(f"ST-SIM score: {st_sim_score}")
print(f"SPQI score: {spqi_score}")
15 changes: 14 additions & 1 deletion test/metrics/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,20 @@ def test_snr__sample_signals(self):

result = snr(sample_distorted_signal, sample_reference_signal)

self.assertAlmostEqual(60, result, delta=1.0)
self.assertAlmostEqual(60, result, delta=2.0)

def test_snr__dist_larger_than_ref__dist_should_be_truncated_and_warn(self):
signal = np.array([0, 1])
dist = np.array([0, 1, 2])
with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"):
result = snr(dist, signal)
self.assertEqual(np.Inf, result)

def test_snr__dist_shorter_than_ref__should_throw(self):
signal = np.array([0, 1, 2])
dist = np.array([0, 1])
with self.assertRaisesRegex(ValueError, r"Distorted .* must not be shorter"):
snr(dist, signal)

def test_nsnr__snr_larger_than_max__should_be_1(self):
signal = np.array([0, 1])
Expand Down
24 changes: 23 additions & 1 deletion test/metrics/test_spqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,26 @@ def test_spqi_wrapper__sample_signals(self):

result = spqi(sample_distorted_signal, sample_reference_signal)

self.assertAlmostEqual(1.0, result)
self.assertGreaterEqual(1, result)
self.assertGreaterEqual(result, 0)

def test_spqi__truncated_signals_identical__dist_should_be_truncated(self):
signal = np.linspace(0, 1, 2800)
dist = np.append(np.linspace(0, 1, 2800), np.ones(300))
with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"):
result = spqi(dist, signal)
self.assertEqual(1, result)

def test_spqi__dist_larger_than_ref__dist_should_be_truncated(self):
signal = np.linspace(0, 1, 2800)
dist = np.append(np.ones(300), np.linspace(0, 1, 2800))
with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"):
result = spqi(dist, signal)
self.assertGreaterEqual(1, result)
self.assertGreaterEqual(result, 0)

def test_spqi__dist_shorter_than_ref__should_throw(self):
signal = np.append(np.linspace(0, 1, 2800), np.ones(300))
dist = np.linspace(0, 1, 2800)
with self.assertRaisesRegex(ValueError, r"Distorted .* must not be shorter"):
spqi(dist, signal)
52 changes: 44 additions & 8 deletions test/metrics/test_stsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,51 @@ def test_st_sim_wrapper__sample_signals(self):

result = st_sim(sample_distorted_signal, sample_reference_signal)

self.assertAlmostEqual(0.85, result, delta=0.1)
self.assertGreaterEqual(1, result)
self.assertGreaterEqual(result, 0)

def test_compute_block_sim__one_block_zero__zero(self):
ref_block = np.ones(4)
dist_block = np.zeros(4)
result = STSIM.compute_block_sim(ref_block, dist_block)
def test_st_sim__truncated_signals_identical__dist_should_be_truncated(self):
signal = np.linspace(0, 1, 2800)
dist = np.append(np.linspace(0, 1, 2800), np.ones(300))
with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"):
result = st_sim(dist, signal)
self.assertEqual(1, result)

def test_st_sim__dist_larger_than_ref__dist_should_be_truncated(self):
signal = np.linspace(0, 1, 2800)
dist = np.append(np.ones(300), np.linspace(0, 1, 2800))
with self.assertWarnsRegex(RuntimeWarning, r"Truncating distorted signal"):
result = st_sim(dist, signal)
self.assertGreaterEqual(1, result)
self.assertGreaterEqual(result, 0)

def test_st_sim__dist_shorter_than_ref__should_throw(self):
signal = np.append(np.linspace(0, 1, 2800), np.ones(300))
dist = np.linspace(0, 1, 2800)
with self.assertRaisesRegex(ValueError, r"Distorted .* must not be shorter"):
st_sim(dist, signal)

def test_compute_sim__one_block_zero__zero(self):
ref_block = np.ones((4, 2))
dist_block = np.zeros((4, 2))
result = STSIM.compute_sim(ref_block, dist_block)
self.assertEqual(0, result)

def test_compute_block_sim__blocks_identical__one(self):
block = np.array([1, 2, 3, 4])
result = STSIM.compute_block_sim(block, block)
def test_compute_sim__blocks_identical__one(self):
block = np.ones((4, 3))
result = STSIM.compute_sim(block, block)
self.assertEqual(1, result)

def test_compute_sim__values_larger_than_one_possible(self):
ref_block = np.array([[0.5] * 4])
dist_block = np.array([[0.5, 0.5, 0.5, 1]])
result = STSIM.compute_sim(ref_block, dist_block)
self.assertEqual(1.25, result)

def test_st_sim_init__eta_grater_one__should_throw(self):
with self.assertRaisesRegex(ValueError, "Eta must be between 0 and 1."):
STSIM(eta=1.1)

def test_st_sim_init__eta_negative__should_throw(self):
with self.assertRaisesRegex(ValueError, "Eta must be between 0 and 1."):
STSIM(eta=-0.1)
3 changes: 2 additions & 1 deletion test/metrics/test_vibromaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ def test_vibromaf_wrapper__sample_signals__some_output_value(self):
sample_distorted_signal, sample_reference_signal, Path("test-model.pickle")
)

self.assertAlmostEqual(0.16, result, delta=0.1)
self.assertGreaterEqual(1, result)
self.assertGreaterEqual(result, 0)
11 changes: 6 additions & 5 deletions test/signal/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,18 @@ def test_mag2db(self):
def test_compute_normalized_spectral_difference__same_signals_should_be_minus_inf(
self,
):
signal = np.ones(10)
signal = np.ones((10, 2))
result = compute_normalized_spectral_difference(signal, signal)
self.assertEqual(-np.inf, result)
self.assertListEqual([-np.inf] * 10, list(result))

def test_compute_normalized_spectral_difference__different_signals_should_be_positive(
self,
):
signal_one = np.ones(10)
signal_two = np.zeros(10)
signal_one = np.ones((2, 10))
signal_two = np.zeros((2, 10))
result = compute_normalized_spectral_difference(signal_one, signal_two)
self.assertAlmostEqual(1.542, result, delta=0.001)
self.assertGreaterEqual(0, result[0])
self.assertEqual(2, result.size)

def test_compute_spectral_support__zeros_array__array_with_0p5(self):
spectrum = np.zeros((2, 4))
Expand Down
27 changes: 27 additions & 0 deletions test/signal/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
compute_block_dct,
compute_block_dft,
cut_off_strategy,
preprocess_input_signal,
zero_padding_strategy,
)

Expand Down Expand Up @@ -87,6 +88,24 @@ def test_zero_padding_strategy__signal_equal_to_block_length__unchanged(
result = zero_padding_strategy(input_signal, 3)
self.assertListEqual([1, 2, 3], list(result))

def test_preprocess_input_signal__dist_larger_than_ref__dist_should_be_truncated_and_warn(
self,
):
signal = np.array([0, 1])
dist = np.array([0, 1, 2])
with self.assertWarnsRegex(
RuntimeWarning,
r"Truncating distorted signal .* since longer than reference",
):
result = preprocess_input_signal(dist, signal)
self.assertListEqual(list(signal), list(result))

def test_preprocess_input_signal__dist_shorter_than_ref__should_throw(self):
signal = np.array([0, 1, 2])
dist = np.array([0, 1])
with self.assertRaisesRegex(ValueError, r"Distorted .* must not be shorter"):
preprocess_input_signal(dist, signal)


class BlockBuilderTest(unittest.TestCase):
"""Block Builder Test"""
Expand Down Expand Up @@ -131,6 +150,14 @@ def test_divide_and_normalize__periodic_array__array_divided_by_two_and_shifted(
result = unit.divide_and_normalize(input_signal)
self.assertListEqual([-1, 1, -1, 1], list(result[0]))

def test_divide_and_normalize__multiple_blocks__correct_reshaped(self):
input_signal = np.array([-2, 2, -2, 2, -2, 2, -2, 2])
unit = BlockBuilder(2)
result = unit.divide_and_normalize(input_signal)
self.assertTrue(
np.array_equal(np.array([[-1, 1], [-1, 1], [-1, 1], [-1, 1]]), result)
)


class PerceptualSpectrumBuilderTest(unittest.TestCase):
"""Perceptual Spectrum Builder Test"""
Expand Down
2 changes: 2 additions & 0 deletions vibromaf/metrics/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from vibromaf.signal.spectrum import pow2db, signal_energy
from vibromaf.signal.transform import preprocess_input_signal


def snr(distorted: np.array, reference: np.array) -> float:
Expand All @@ -17,6 +18,7 @@ def snr(distorted: np.array, reference: np.array) -> float:
-------
* `float` The SNR.
"""
distorted = preprocess_input_signal(distorted, reference)
return pow2db(signal_energy(reference) / signal_energy(distorted - reference))


Expand Down
Loading

0 comments on commit c41c524

Please sign in to comment.