11import inspect
22import warnings
3+ from operator import eq , lt
34
45import naive
56import numpy as np
89
910from stumpy import sdp
1011
11- test_data = [
12- (np .array ([- 1 , 1 , 2 ], dtype = np .float64 ), np .array (range (5 ), dtype = np .float64 )),
12+ # README
13+ # Real FFT algorithm performs more efficiently when the length
14+ # of the input array `arr` is composed of small prime factors.
15+ # The next_fast_len(arr, real=True) function from Scipy returns
16+ # the same length if len(arr) is composed of a subset of
17+ # prime numbers 2, 3, 5. Therefore, these radices are
18+ # considered as the most efficient for the real FFT algorithm.
19+
20+ # To ensure that the tests cover different cases, the following cases
21+ # are considered:
22+ # 1. len(T) is even, and len(T) == next_fast_len(len(T), real=True)
23+ # 2. len(T) is odd, and len(T) == next_fast_len(len(T), real=True)
24+ # 3. len(T) is even, and len(T) < next_fast_len(len(T), real=True)
25+ # 4. len(T) is odd, and len(T) < next_fast_len(len(T), real=True)
26+ # And 5. a special case of 1, where len(T) is power of 2.
27+
28+ # Therefore:
29+ # 1. len(T) is composed of 2 and a subset of {3, 5}
30+ # 2. len(T) is composed of a subset of {3, 5}
31+ # 3. len(T) is composed of a subset of {7, 11, 13, ...} and 2
32+ # 4. len(T) is composed of a subset of {7, 11, 13, ...}
33+ # 5. len(T) is power of 2
34+
35+ # In some cases, the prime factors are raised to a power of
36+ # certain degree to increase the length of array to be around
37+ # 1000-2000. This allows us to test sliding_dot_product for
38+ # wider range of query lengths.
39+
40+ test_inputs = [
41+ # Input format:
42+ # (
43+ # len(T),
44+ # remainder, # from `len(T) % 2`
45+ # comparator, # for len(T) comparator next_fast_len(len(T), real=True)
46+ # )
1347 (
14- np .array ([9 , 8100 , - 60 ], dtype = np .float64 ),
15- np .array ([584 , - 11 , 23 , 79 , 1001 ], dtype = np .float64 ),
16- ),
17- (np .random .uniform (- 1000 , 1000 , [8 ]), np .random .uniform (- 1000 , 1000 , [64 ])),
48+ 2 * (3 ** 2 ) * (5 ** 3 ),
49+ 0 ,
50+ eq ,
51+ ), # = 2250, Even `len(T)`, and `len(T) == next_fast_len(len(T), real=True)`
52+ (
53+ (3 ** 2 ) * (5 ** 3 ),
54+ 1 ,
55+ eq ,
56+ ), # = 1125, Odd `len(T)`, and `len(T) == next_fast_len(len(T), real=True)`.
57+ (
58+ 2 * 7 * 11 * 13 ,
59+ 0 ,
60+ lt ,
61+ ), # = 2002, Even `len(T)`, and `len(T) < next_fast_len(len(T), real=True)`
62+ (
63+ 7 * 11 * 13 ,
64+ 1 ,
65+ lt ,
66+ ), # = 1001, Odd `len(T)`, and `len(T) < next_fast_len(len(T), real=True)`
1867]
1968
2069
@@ -27,15 +76,78 @@ def get_sdp_function_names():
2776 return out
2877
2978
30- @pytest .mark .parametrize ("Q, T" , test_data )
31- def test_sliding_dot_product (Q , T ):
79+ @pytest .mark .parametrize ("n_T, remainder, comparator" , test_inputs )
80+ def test_sdp (n_T , remainder , comparator ):
81+ # test_sdp for cases 1-4
82+
83+ n_Q_prime = [
84+ 2 ,
85+ 3 ,
86+ 5 ,
87+ 7 ,
88+ 11 ,
89+ 13 ,
90+ 17 ,
91+ 19 ,
92+ 23 ,
93+ 29 ,
94+ 31 ,
95+ 37 ,
96+ 41 ,
97+ 43 ,
98+ 47 ,
99+ 53 ,
100+ 59 ,
101+ 61 ,
102+ 67 ,
103+ 71 ,
104+ 73 ,
105+ 79 ,
106+ 83 ,
107+ 89 ,
108+ 97 ,
109+ ]
110+ n_Q_power2 = [2 , 4 , 8 , 16 , 32 , 64 ]
111+ n_Q_values = n_Q_prime + n_Q_power2 + [n_T ]
112+ n_Q_values = sorted (n_Q for n_Q in set (n_Q_values ) if n_Q <= n_T )
113+
114+ for n_Q in n_Q_values :
115+ Q = np .random .rand (n_Q )
116+ T = np .random .rand (n_T )
117+ ref = naive .rolling_window_dot_product (Q , T )
118+ for func_name in get_sdp_function_names ():
119+ func = getattr (sdp , func_name )
120+ try :
121+ comp = func (Q , T )
122+ npt .assert_allclose (comp , ref )
123+ except Exception as e : # pragma: no cover
124+ msg = f"Error in { func_name } , with n_Q={ len (Q )} and n_T={ len (T )} "
125+ warnings .warn (msg )
126+ raise e
127+
128+
129+ def test_sdp_power2 ():
130+ # test for case 5. len(T) is power of 2
131+ pmin = 3
132+ pmax = 13
133+
32134 for func_name in get_sdp_function_names ():
33135 func = getattr (sdp , func_name )
34136 try :
35- comp = func (Q , T )
36- ref = naive .rolling_window_dot_product (Q , T )
37- npt .assert_allclose (comp , ref )
137+ for q in range (pmin , pmax + 1 ):
138+ n_Q = 2 ** q
139+ for p in range (q , pmax + 1 ):
140+ n_T = 2 ** p
141+ Q = np .random .rand (n_Q )
142+ T = np .random .rand (n_T )
143+
144+ ref = naive .rolling_window_dot_product (Q , T )
145+ comp = func (Q , T )
146+ npt .assert_allclose (comp , ref )
147+
38148 except Exception as e : # pragma: no cover
39- msg = f"Error in { func_name } , with n_Q= { len ( Q ) } and n_T= { len ( T ) } "
149+ msg = f"Error in { func_name } , with q= { q } and p= { p } "
40150 warnings .warn (msg )
41151 raise e
152+
153+ return
0 commit comments