|
1 | 1 | import functools |
2 | | -import importlib |
3 | 2 | from unittest.mock import patch |
4 | 3 |
|
5 | 4 | import naive |
@@ -149,75 +148,48 @@ def test_snippets(): |
149 | 148 | cmp_regimes, |
150 | 149 | ) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func) |
151 | 150 |
|
152 | | - if np.allclose(ref_snippets, cmp_snippets) or numba.config.DISABLE_JIT: |
153 | | - npt.assert_almost_equal( |
154 | | - ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION |
155 | | - ) |
156 | | - npt.assert_almost_equal( |
157 | | - ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION |
158 | | - ) |
159 | | - npt.assert_almost_equal( |
160 | | - ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION |
161 | | - ) |
162 | | - npt.assert_almost_equal( |
163 | | - ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION |
164 | | - ) |
165 | | - npt.assert_almost_equal( |
166 | | - ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION |
167 | | - ) |
168 | | - npt.assert_almost_equal(ref_regimes, cmp_regimes) |
169 | | - else: |
170 | | - # Revise fastmath flag, recompile, and re-calculate snippets, |
171 | | - # and then revert the changes |
172 | | - |
| 151 | + if not np.allclose(ref_snippets, cmp_snippets) and not numba.config.DISABLE_JIT: |
| 152 | + # Revise fastmath flags by removing reassoc (to improve precision), |
| 153 | + # recompile njit functions, and re-compute snippets. |
173 | 154 | config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"} |
174 | | - core._calculate_squared_distance.targetoptions["fastmath"] = ( |
175 | | - config.STUMPY_FASTMATH_FLAGS |
| 155 | + cache._recompile( |
| 156 | + core._calculate_squared_distance, fastmath=config.STUMPY_FASTMATH_FLAGS |
176 | 157 | ) |
177 | | - |
178 | | - njit_funcs = cache.get_njit_funcs() |
179 | | - for module_name, func_name in njit_funcs: |
180 | | - module = importlib.import_module(f".{module_name}", package="stumpy") |
181 | | - func = getattr(module, func_name) |
182 | | - func.recompile() |
| 158 | + cache._recompile() |
183 | 159 |
|
184 | 160 | ( |
185 | | - cmp_snippets_NOreassoc, |
186 | | - cmp_indices_NOreassoc, |
187 | | - cmp_profiles_NOreassoc, |
188 | | - cmp_fractions_NOreassoc, |
189 | | - cmp_areas_NOreassoc, |
190 | | - cmp_regimes_NOreassoc, |
| 161 | + cmp_snippets, |
| 162 | + cmp_indices, |
| 163 | + cmp_profiles, |
| 164 | + cmp_fractions, |
| 165 | + cmp_areas, |
| 166 | + cmp_regimes, |
191 | 167 | ) = stumpy.snippets( |
192 | 168 | T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func |
193 | 169 | ) |
194 | 170 |
|
195 | | - config._reset("STUMPY_FASTMATH_FLAGS") |
196 | | - |
197 | | - core._calculate_squared_distance.targetoptions["fastmath"] = ( |
198 | | - config.STUMPY_FASTMATH_FLAGS |
199 | | - ) |
200 | | - for module_name, func_name in njit_funcs: |
201 | | - module = importlib.import_module(f".{module_name}", package="stumpy") |
202 | | - func = getattr(module, func_name) |
203 | | - func.recompile() |
| 171 | + npt.assert_almost_equal( |
| 172 | + ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION |
| 173 | + ) |
| 174 | + npt.assert_almost_equal( |
| 175 | + ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION |
| 176 | + ) |
| 177 | + npt.assert_almost_equal( |
| 178 | + ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION |
| 179 | + ) |
| 180 | + npt.assert_almost_equal( |
| 181 | + ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION |
| 182 | + ) |
| 183 | + npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION) |
| 184 | + npt.assert_almost_equal(ref_regimes, cmp_regimes) |
204 | 185 |
|
205 | | - npt.assert_almost_equal( |
206 | | - ref_snippets, cmp_snippets_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
207 | | - ) |
208 | | - npt.assert_almost_equal( |
209 | | - ref_indices, cmp_indices_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
210 | | - ) |
211 | | - npt.assert_almost_equal( |
212 | | - ref_profiles, cmp_profiles_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
213 | | - ) |
214 | | - npt.assert_almost_equal( |
215 | | - ref_fractions, cmp_fractions_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
216 | | - ) |
217 | | - npt.assert_almost_equal( |
218 | | - ref_areas, cmp_areas_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
| 186 | + if not numba.config.DISABLE_JIT: |
| 187 | + # Revert fastmath flag back to their default values |
| 188 | + config._reset("STUMPY_FASTMATH_FLAGS") |
| 189 | + cache._recompile( |
| 190 | + core._calculate_squared_distance, fastmath=config.STUMPY_FASTMATH_FLAGS |
219 | 191 | ) |
220 | | - npt.assert_almost_equal(ref_regimes, cmp_regimes_NOreassoc) |
| 192 | + cache._recompile() |
221 | 193 |
|
222 | 194 |
|
223 | 195 | @pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning) |
|
0 commit comments