diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py index 4721c65..65850bc 100644 --- a/src/jax_finufft/ops.py +++ b/src/jax_finufft/ops.py @@ -60,6 +60,22 @@ def nufft2(source, *points, iflag=-1, eps=1e-6, opts=None): return index.unflatten(result) +def get_frequency_array(n, modeord): + if modeord == 0: + return np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) + elif modeord == 1: + if n % 2 == 0: + pos = np.arange(0, n // 2) + neg = np.arange(-n // 2, 0) + return np.concatenate([pos, neg]) + else: + pos = np.arange(0, (n + 1) // 2) + neg = np.arange(-(n - 1) // 2, 0) + return np.concatenate([pos, neg]) + else: + raise ValueError(f"Unsupported modeord: {modeord}") + + def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts): # Type 1: # f_k = sum_j c_j * exp(iflag * i * k * x_j) @@ -75,6 +91,9 @@ def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts): source, *points, output_shape=output_shape, iflag=iflag, eps=eps, opts=opts ) + # Extract modeord from opts + modeord = opts.modeord if hasattr(opts, "modeord") else 0 + # The JVP op can be written as a single transform of the same type with output_tangents = [] ndim = len(points) @@ -105,7 +124,7 @@ def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts): n = source.shape[-ndim + dim] if output_shape is None else output_shape[dim] shape = np.ones(ndim, dtype=int) shape[dim] = -1 - k = np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) + k = get_frequency_array(n, modeord) k = k.reshape(shape) factor = 1j * iflag * k dx = dx[:, None, :]