Skip to content

Commit d13a32b

Browse files
committed
Refactored the recompile process
1 parent b24ff9d commit d13a32b

File tree

2 files changed

+74
-60
lines changed

2 files changed

+74
-60
lines changed

stumpy/cache.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import site
1010
import warnings
1111

12+
import numba
13+
1214
CACHE_WARNING = "Caching `numba` functions is purely for experimental purposes "
1315
CACHE_WARNING += "and should never be used or depended upon as it is not supported! "
1416
CACHE_WARNING += "All caching capabilities are not tested and may be removed/changed "
@@ -107,3 +109,43 @@ def _get_cache():
107109
site_pkg_dir = site.getsitepackages()[0]
108110
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
109111
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
112+
113+
114+
def _recompile(func=None, fastmath=None):
115+
"""
116+
Recompile a jit/njit decorated function. If `func` is None, then it wiil
117+
recompile all njit functions of STUMPY.
118+
119+
Parameters
120+
----------
121+
func : a njit function, default None
122+
The numba function to recompile. If None, then all njit functions
123+
of STUMPY will be recompiled.
124+
125+
fastmath : bool or set, default None
126+
The fastmath flags to use. If None, then the func's fastmath flags
127+
will not be changed. This is only used when `func` is provided.
128+
129+
Returns
130+
-------
131+
None
132+
"""
133+
warnings.warn(CACHE_WARNING)
134+
if func is None:
135+
njit_funcs = get_njit_funcs()
136+
for module_name, func_name in njit_funcs:
137+
module = importlib.import_module(f".{module_name}", package="stumpy")
138+
func = getattr(module, func_name)
139+
func.recompile()
140+
141+
else:
142+
if not numba.extending.is_jitted(func):
143+
msg = "The function `func` must be a (n)jit function."
144+
raise ValueError(msg)
145+
146+
if fastmath is not None:
147+
func.targetoptions["fastmath"] = fastmath
148+
149+
func.recompile()
150+
151+
return

tests/test_precision.py

Lines changed: 32 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import functools
2-
import importlib
32
from unittest.mock import patch
43

54
import naive
@@ -149,75 +148,48 @@ def test_snippets():
149148
cmp_regimes,
150149
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
151150

152-
if np.allclose(ref_snippets, cmp_snippets) or numba.config.DISABLE_JIT:
153-
npt.assert_almost_equal(
154-
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
155-
)
156-
npt.assert_almost_equal(
157-
ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION
158-
)
159-
npt.assert_almost_equal(
160-
ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION
161-
)
162-
npt.assert_almost_equal(
163-
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
164-
)
165-
npt.assert_almost_equal(
166-
ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION
167-
)
168-
npt.assert_almost_equal(ref_regimes, cmp_regimes)
169-
else:
170-
# Revise fastmath flag, recompile, and re-calculate snippets,
171-
# and then revert the changes
172-
151+
if not np.allclose(ref_snippets, cmp_snippets) and not numba.config.DISABLE_JIT:
152+
# Revise fastmath flags by removing reassoc (to improve precision),
153+
# recompile njit functions, and re-compute snippets.
173154
config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}
174-
core._calculate_squared_distance.targetoptions["fastmath"] = (
175-
config.STUMPY_FASTMATH_FLAGS
155+
cache._recompile(
156+
core._calculate_squared_distance, fastmath=config.STUMPY_FASTMATH_FLAGS
176157
)
177-
178-
njit_funcs = cache.get_njit_funcs()
179-
for module_name, func_name in njit_funcs:
180-
module = importlib.import_module(f".{module_name}", package="stumpy")
181-
func = getattr(module, func_name)
182-
func.recompile()
158+
cache._recompile()
183159

184160
(
185-
cmp_snippets_NOreassoc,
186-
cmp_indices_NOreassoc,
187-
cmp_profiles_NOreassoc,
188-
cmp_fractions_NOreassoc,
189-
cmp_areas_NOreassoc,
190-
cmp_regimes_NOreassoc,
161+
cmp_snippets,
162+
cmp_indices,
163+
cmp_profiles,
164+
cmp_fractions,
165+
cmp_areas,
166+
cmp_regimes,
191167
) = stumpy.snippets(
192168
T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func
193169
)
194170

195-
config._reset("STUMPY_FASTMATH_FLAGS")
196-
197-
core._calculate_squared_distance.targetoptions["fastmath"] = (
198-
config.STUMPY_FASTMATH_FLAGS
199-
)
200-
for module_name, func_name in njit_funcs:
201-
module = importlib.import_module(f".{module_name}", package="stumpy")
202-
func = getattr(module, func_name)
203-
func.recompile()
171+
npt.assert_almost_equal(
172+
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
173+
)
174+
npt.assert_almost_equal(
175+
ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION
176+
)
177+
npt.assert_almost_equal(
178+
ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION
179+
)
180+
npt.assert_almost_equal(
181+
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
182+
)
183+
npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION)
184+
npt.assert_almost_equal(ref_regimes, cmp_regimes)
204185

205-
npt.assert_almost_equal(
206-
ref_snippets, cmp_snippets_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
207-
)
208-
npt.assert_almost_equal(
209-
ref_indices, cmp_indices_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
210-
)
211-
npt.assert_almost_equal(
212-
ref_profiles, cmp_profiles_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
213-
)
214-
npt.assert_almost_equal(
215-
ref_fractions, cmp_fractions_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
216-
)
217-
npt.assert_almost_equal(
218-
ref_areas, cmp_areas_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
186+
if not numba.config.DISABLE_JIT:
187+
# Revert fastmath flag back to their default values
188+
config._reset("STUMPY_FASTMATH_FLAGS")
189+
cache._recompile(
190+
core._calculate_squared_distance, fastmath=config.STUMPY_FASTMATH_FLAGS
219191
)
220-
npt.assert_almost_equal(ref_regimes, cmp_regimes_NOreassoc)
192+
cache._recompile()
221193

222194

223195
@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)

0 commit comments

Comments
 (0)