Skip to content

Commit 2837609

Browse files
committed
move convolve_sdp and add test
1 parent d0b2375 commit 2837609

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

stumpy/sdp.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from numba import njit
3+
from scipy.signal import convolve
34

45
from . import config
56

@@ -31,10 +32,34 @@ def _njit_sliding_dot_product(Q, T):
3132
return out
3233

3334

35+
def _convolve_sliding_dot_product(Q, T):
36+
"""
37+
Use (direct or FFT) convolution to calculate the sliding window dot product.
38+
39+
Parameters
40+
----------
41+
Q : numpy.ndarray
42+
Query array or subsequence
43+
44+
T : numpy.ndarray
45+
Time series or sequence
46+
47+
Returns
48+
-------
49+
output : numpy.ndarray
50+
Sliding dot product between `Q` and `T`.
51+
"""
52+
n = T.shape[0]
53+
m = Q.shape[0]
54+
Qr = np.flipud(Q) # Reverse/flip Q
55+
QT = convolve(Qr, T)
56+
57+
return QT.real[m - 1 : n]
58+
59+
3460
def _sliding_dot_product(Q, T):
3561
"""
36-
A wrapper function for the Numba JIT-compiled implementation of the sliding
37-
window dot product.
62+
Compute the sliding dot product between `Q` and `T`
3863
3964
Parameters
4065
----------
@@ -49,4 +74,4 @@ def _sliding_dot_product(Q, T):
4974
out : numpy.ndarray
5075
Sliding dot product between `Q` and `T`.
5176
"""
52-
return _njit_sliding_dot_product(Q, T)
77+
return _convolve_sliding_dot_product(Q, T)

tests/test_sdp.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ def test_njit_sliding_dot_product(Q, T):
3030
npt.assert_almost_equal(ref_mp, comp_mp)
3131

3232

33+
@pytest.mark.parametrize("Q, T", test_data)
34+
def test_convolve_sliding_dot_product(Q, T):
35+
ref_mp = naive_rolling_window_dot_product(Q, T)
36+
comp_mp = sdp._convolve_sliding_dot_product(Q, T)
37+
npt.assert_almost_equal(ref_mp, comp_mp)
38+
39+
3340
@pytest.mark.parametrize("Q, T", test_data)
3441
def test_sliding_dot_product(Q, T):
3542
ref_mp = naive_rolling_window_dot_product(Q, T)

0 commit comments

Comments
 (0)