Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pywt/tests/test_cwt_wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
assert_almost_equal,
assert_equal,
assert_raises,
assert_warns,
)

import pywt
Expand Down Expand Up @@ -341,7 +340,8 @@ def test_cwt_parameters_in_names():
for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]:
for name in ['fbsp', 'cmor', 'shan']:
# additional parameters should be specified within the name
assert_warns(FutureWarning, func, name)
with pytest.warns(FutureWarning):
func(name)

for name in ['cmor', 'shan']:
# valid names
Expand Down
24 changes: 16 additions & 8 deletions pywt/tests/test_deprecations.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,40 @@
import warnings

import numpy as np
from numpy.testing import assert_array_equal, assert_warns
import pytest
from numpy.testing import assert_array_equal

import pywt


def test_intwave_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.intwave, wavelet)
with pytest.warns(DeprecationWarning):
pywt.intwave(wavelet)


def test_centrfrq_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.centrfrq, wavelet)
with pytest.warns(DeprecationWarning):
pywt.centrfrq(wavelet)


def test_scal2frq_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1)
with pytest.warns(DeprecationWarning):
pywt.scal2frq(wavelet, 1)


def test_orthfilt_deprecation():
assert_warns(DeprecationWarning, pywt.orthfilt, range(6))
with pytest.warns(DeprecationWarning):
pywt.orthfilt(range(6))


def test_integrate_wave_tuple():
sig = [0, 1, 2, 3]
xgrid = [0, 1, 2, 3]
assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid))
with pytest.warns(DeprecationWarning):
pywt.integrate_wavelet((sig, xgrid))


old_modes = ['zpd',
Expand All @@ -42,15 +48,17 @@ def test_integrate_wave_tuple():

def test_MODES_from_object_deprecation():
for mode in old_modes:
assert_warns(DeprecationWarning, pywt.Modes.from_object, mode)
with pytest.warns(DeprecationWarning):
pywt.Modes.from_object(mode)


def test_MODES_attributes_deprecation():
def get_mode(Modes, name):
return getattr(Modes, name)

for mode in old_modes:
assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode)
with pytest.warns(DeprecationWarning):
get_mode(pywt.Modes, mode)


def test_mode_equivalence():
Expand Down
10 changes: 5 additions & 5 deletions pywt/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
assert_equal,
assert_raises,
assert_raises_regex,
assert_warns,
)

import pywt
Expand Down Expand Up @@ -899,8 +898,9 @@ def test_fswavedecn_fswaverecn_variable_levels():
assert_raises(ValueError, pywt.fswavedecn, data, 'haar', levels=(1, 1, 1, 1))

# levels too large for array size
assert_warns(UserWarning, pywt.fswavedecn, data, 'haar',
levels=int(np.log2(np.min(data.shape)))+1)
with pytest.warns(UserWarning):
pywt.fswavedecn(data, 'haar',
levels=int(np.log2(np.min(data.shape)))+1)


def test_fswavedecn_fswaverecn_variable_wavelets_and_modes():
Expand Down Expand Up @@ -967,8 +967,8 @@ def test_fswavedecnresult():
k, np.zeros(tuple([s + 1 for s in d.shape])))

# warns on assigning with a non-matching dtype
assert_warns(UserWarning, result.__setitem__,
k, np.zeros_like(d).astype(np.float32))
with pytest.warns(UserWarning):
result.__setitem__(k, np.zeros_like(d).astype(np.float32))

# all coefficients are stacked into result.coeffs (same ndim)
assert_equal(result.coeffs.ndim, data.ndim)
Expand Down
81 changes: 50 additions & 31 deletions pywt/tests/test_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
assert_allclose,
assert_array_equal,
assert_equal,
assert_raises,
assert_warns,
)

import pywt
Expand Down Expand Up @@ -69,7 +67,9 @@ def test_swt_decomposition():

def test_swt_max_level():
# odd sized signal will warn about no levels of decomposition possible
assert_warns(UserWarning, pywt.swt_max_level, 11)
with pytest.warns(UserWarning):
pywt.swt_max_level(11)

with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning)
assert_equal(pywt.swt_max_level(11), 0)
Expand Down Expand Up @@ -134,7 +134,8 @@ def test_swt_axis():
assert_array_equal(row, cD2)

# axis too large
assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5)
with pytest.raises(ValueError):
pywt.swt(x, db1, level=2, axis=5)


def test_swt_iswt_integration():
Expand Down Expand Up @@ -217,9 +218,8 @@ def test_swt_default_level_by_axis():

def test_swt2_ndim_error():
x = np.ones(8)
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
with pytest.raises(ValueError):
pywt.swt2(x, 'haar', level=1)


@pytest.mark.slow
Expand Down Expand Up @@ -298,10 +298,12 @@ def test_swt2_axes():
assert_allclose(X, r2, atol=atol)

# duplicate axes not allowed
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1,
axes=(0, 0))
with pytest.raises(ValueError):
pywt.swt2(X, current_wavelet, 1, axes=(0, 0))

# too few axes
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, ))
with pytest.raises(ValueError):
pywt.swt2(X, current_wavelet, 1, axes=(0, ))


def test_swtn_axes():
Expand All @@ -325,21 +327,24 @@ def test_swtn_axes():
assert_equal(empty, [])

# duplicate axes not allowed
assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
with pytest.raises(ValueError):
pywt.swtn(X, current_wavelet, 1, axes=(0, 0))

# data.ndim = 0
assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
with pytest.raises(ValueError):
pywt.swtn(np.asarray([]), current_wavelet, 1)

# start_level too large
assert_raises(ValueError, pywt.swtn, X, current_wavelet,
level=1, start_level=2)
with pytest.raises(ValueError):
pywt.swtn(X, current_wavelet, level=1, start_level=2)

# level < 1 in swt_axis call
assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
start_level=0)
with pytest.raises(ValueError):
swt_axis(X, current_wavelet, level=0, start_level=0)

# odd-sized data not allowed
assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0,
start_level=0, axis=0)
with pytest.raises(ValueError):
swt_axis( X[-1, :], current_wavelet, level=0, start_level=0, axis=0)


@pytest.mark.slow
Expand Down Expand Up @@ -401,12 +406,17 @@ def test_iswtn_errors():
coeffs = pywt.swtn(x, w, max_level, axes=axes)

# more axes than dimensions transformed
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
with pytest.raises(ValueError):
pywt.iswtn(coeffs, w, axes=(0, 1, 2))

# duplicate axes not allowed
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
with pytest.raises(ValueError):
pywt.iswtn(coeffs, w, axes=(0, 0))

# mismatched coefficient size
coeffs[0]['da'] = coeffs[0]['da'][:-1, :]
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
with pytest.raises(RuntimeError):
pywt.iswtn(coeffs, w, axes=axes)


def test_swtn_iswtn_unique_shape_per_axis():
Expand Down Expand Up @@ -441,8 +451,11 @@ def test_per_axis_wavelets():
assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)

# length of wavelets doesn't match the length of axes
assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])
with pytest.raises(ValueError):
pywt.swtn(data, wavelets[:2], level)

with pytest.raises(ValueError):
pywt.iswtn(coefs, wavelets[:2])

with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
Expand All @@ -458,11 +471,12 @@ def test_error_on_continuous_wavelet():
for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn],
[pywt.iswt, pywt.iswt2, pywt.iswtn]):
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
assert_raises(ValueError, dec_func, data, wavelet=cwave,
level=3)
with pytest.raises(ValueError):
dec_func(data, wavelet=cwave, level=3)

c = dec_func(data, 'db1', level=3)
assert_raises(ValueError, rec_func, c, wavelet=cwave)
with pytest.raises(ValueError):
rec_func(c, wavelet=cwave)


def test_iswt_mixed_dtypes():
Expand Down Expand Up @@ -552,11 +566,13 @@ def test_iswtn_mixed_dtypes():

def test_swt_zero_size_axes():
# raise on empty input array
assert_raises(ValueError, pywt.swt, [], 'db2')
with pytest.raises(ValueError):
pywt.swt([], 'db2')

# >1D case uses a different code path so check there as well
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
with pytest.raises(ValueError):
pywt.swtn(x, 'db2', level=1, axes=(0,))


def test_swt_variance_and_energy_preservation():
Expand All @@ -575,7 +591,8 @@ def test_swt_variance_and_energy_preservation():
np.linalg.norm(np.concatenate(coeffs)))

# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True)
with pytest.warns(UserWarning):
pywt.swt(x, 'bior2.2', norm=True)


def test_swt2_variance_and_energy_preservation():
Expand All @@ -598,7 +615,8 @@ def test_swt2_variance_and_energy_preservation():
np.linalg.norm(np.concatenate(coeff_list)))

# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True)
with pytest.warns(UserWarning):
pywt.swt2(x, 'bior2.2', level=4, norm=True)


def test_swtn_variance_and_energy_preservation():
Expand All @@ -621,7 +639,8 @@ def test_swtn_variance_and_energy_preservation():
np.linalg.norm(np.concatenate(coeff_list)))

# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
with pytest.warns(UserWarning):
pywt.swtn(x, 'bior2.2', level=4, norm=True)


def test_swt_ravel_and_unravel():
Expand Down