diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a9995706..5a2129e6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,6 +53,7 @@ jobs: run: | pip install pytest pip install torch torchvision + pip install numba pytest tests # - name: Install from dist diff --git a/tests/equi2pers/numpy_inv.py b/tests/equi2pers/numpy_inv.py index 6c49f9e9..3d97beea 100644 --- a/tests/equi2pers/numpy_inv.py +++ b/tests/equi2pers/numpy_inv.py @@ -5,11 +5,7 @@ from timeit import timeit -try: - from numba import jit -except ImportError: - print("numba not available") - jit = None +from numba import jit import numpy as np @@ -59,7 +55,7 @@ def hdinv(A): return invA -# @jit("float64[:,:](float64[:,:])", cache=True, nopython=True, nogil=True) +@jit("float64[:,:](float64[:,:])", cache=True, nopython=True, nogil=True) def fast_inverse(A): inv = np.empty_like(A) a = A[0, 0] @@ -89,7 +85,7 @@ def fast_inverse(A): return inv -# @jit(cache=True, nopython=True, nogil=True) +@jit(cache=True, nopython=True, nogil=True) def vecinv(A): invA = np.zeros_like(A) for i in range(A.shape[0]): diff --git a/tests/grid_sample/numpy/nearest.py b/tests/grid_sample/numpy/nearest.py index 4eac5ecd..72c7fbbb 100644 --- a/tests/grid_sample/numpy/nearest.py +++ b/tests/grid_sample/numpy/nearest.py @@ -4,11 +4,7 @@ """ -try: - from numba import njit -except ImportError: - print("numba not available") - njit = None +from numba import njit import numpy as np @@ -80,7 +76,7 @@ def faster_nearest( return out -# @njit +@njit def run(img, grid, out, b, h, w): for i in range(b): for y_out in range(h):