Skip to content

Commit 8b6191b

Browse files
committed
enhanced test cases
1 parent 05b073c commit 8b6191b

File tree

1 file changed

+124
-12
lines changed

1 file changed

+124
-12
lines changed

tests/test_sdp.py

Lines changed: 124 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import warnings
3+
from operator import eq, lt
34

45
import naive
56
import numpy as np
@@ -8,13 +9,61 @@
89

910
from stumpy import sdp
1011

11-
test_data = [
12-
(np.array([-1, 1, 2], dtype=np.float64), np.array(range(5), dtype=np.float64)),
12+
# README
13+
# Real FFT algorithm performs more efficiently when the length
14+
# of the input array `arr` is composed of small prime factors.
15+
# The next_fast_len(arr, real=True) function from Scipy returns
16+
# the same length if len(arr) is composed of a subset of
17+
# prime numbers 2, 3, 5. Therefore, these radices are
18+
# considered as the most efficient for the real FFT algorithm.
19+
20+
# To ensure that the tests cover different cases, the following cases
21+
# are considered:
22+
# 1. len(T) is even, and len(T) == next_fast_len(len(T), real=True)
23+
# 2. len(T) is odd, and len(T) == next_fast_len(len(T), real=True)
24+
# 3. len(T) is even, and len(T) < next_fast_len(len(T), real=True)
25+
# 4. len(T) is odd, and len(T) < next_fast_len(len(T), real=True)
26+
# And 5. a special case of 1, where len(T) is power of 2.
27+
28+
# Therefore:
29+
# 1. len(T) is composed of 2 and a subset of {3, 5}
30+
# 2. len(T) is composed of a subset of {3, 5}
31+
# 3. len(T) is composed of a subset of {7, 11, 13, ...} and 2
32+
# 4. len(T) is composed of a subset of {7, 11, 13, ...}
33+
# 5. len(T) is power of 2
34+
35+
# In some cases, the prime factors are raised to a power of
36+
# certain degree to increase the length of array to be around
37+
# 1000-2000. This allows us to test sliding_dot_product for
38+
# wider range of query lengths.
39+
40+
test_inputs = [
41+
# Input format:
42+
# (
43+
# len(T),
44+
# remainder, # from `len(T) % 2`
45+
# comparator, # for len(T) comparator next_fast_len(len(T), real=True)
46+
# )
1347
(
14-
np.array([9, 8100, -60], dtype=np.float64),
15-
np.array([584, -11, 23, 79, 1001], dtype=np.float64),
16-
),
17-
(np.random.uniform(-1000, 1000, [8]), np.random.uniform(-1000, 1000, [64])),
48+
2 * (3**2) * (5**3),
49+
0,
50+
eq,
51+
), # = 2250, Even `len(T)`, and `len(T) == next_fast_len(len(T), real=True)`
52+
(
53+
(3**2) * (5**3),
54+
1,
55+
eq,
56+
), # = 1125, Odd `len(T)`, and `len(T) == next_fast_len(len(T), real=True)`.
57+
(
58+
2 * 7 * 11 * 13,
59+
0,
60+
lt,
61+
), # = 2002, Even `len(T)`, and `len(T) < next_fast_len(len(T), real=True)`
62+
(
63+
7 * 11 * 13,
64+
1,
65+
lt,
66+
), # = 1001, Odd `len(T)`, and `len(T) < next_fast_len(len(T), real=True)`
1867
]
1968

2069

@@ -27,15 +76,78 @@ def get_sdp_function_names():
2776
return out
2877

2978

30-
@pytest.mark.parametrize("Q, T", test_data)
31-
def test_sliding_dot_product(Q, T):
79+
@pytest.mark.parametrize("n_T, remainder, comparator", test_inputs)
80+
def test_sdp(n_T, remainder, comparator):
81+
# test_sdp for cases 1-4
82+
83+
n_Q_prime = [
84+
2,
85+
3,
86+
5,
87+
7,
88+
11,
89+
13,
90+
17,
91+
19,
92+
23,
93+
29,
94+
31,
95+
37,
96+
41,
97+
43,
98+
47,
99+
53,
100+
59,
101+
61,
102+
67,
103+
71,
104+
73,
105+
79,
106+
83,
107+
89,
108+
97,
109+
]
110+
n_Q_power2 = [2, 4, 8, 16, 32, 64]
111+
n_Q_values = n_Q_prime + n_Q_power2 + [n_T]
112+
n_Q_values = sorted(n_Q for n_Q in set(n_Q_values) if n_Q <= n_T)
113+
114+
for n_Q in n_Q_values:
115+
Q = np.random.rand(n_Q)
116+
T = np.random.rand(n_T)
117+
ref = naive.rolling_window_dot_product(Q, T)
118+
for func_name in get_sdp_function_names():
119+
func = getattr(sdp, func_name)
120+
try:
121+
comp = func(Q, T)
122+
npt.assert_allclose(comp, ref)
123+
except Exception as e: # pragma: no cover
124+
msg = f"Error in {func_name}, with n_Q={len(Q)} and n_T={len(T)}"
125+
warnings.warn(msg)
126+
raise e
127+
128+
129+
def test_sdp_power2():
130+
# test for case 5. len(T) is power of 2
131+
pmin = 3
132+
pmax = 13
133+
32134
for func_name in get_sdp_function_names():
33135
func = getattr(sdp, func_name)
34136
try:
35-
comp = func(Q, T)
36-
ref = naive.rolling_window_dot_product(Q, T)
37-
npt.assert_allclose(comp, ref)
137+
for q in range(pmin, pmax + 1):
138+
n_Q = 2**q
139+
for p in range(q, pmax + 1):
140+
n_T = 2**p
141+
Q = np.random.rand(n_Q)
142+
T = np.random.rand(n_T)
143+
144+
ref = naive.rolling_window_dot_product(Q, T)
145+
comp = func(Q, T)
146+
npt.assert_allclose(comp, ref)
147+
38148
except Exception as e: # pragma: no cover
39-
msg = f"Error in {func_name}, with n_Q={len(Q)} and n_T={len(T)}"
149+
msg = f"Error in {func_name}, with q={q} and p={p}"
40150
warnings.warn(msg)
41151
raise e
152+
153+
return

0 commit comments

Comments
 (0)