From 950bc3560ff4fe75cbedcac13eb72890e011154f Mon Sep 17 00:00:00 2001 From: Tom Hilder Date: Wed, 20 Aug 2025 17:55:47 +1000 Subject: [PATCH 1/2] fix for different mode orderings in jvp --- src/jax_finufft/ops.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py index 4721c65..22155bc 100644 --- a/src/jax_finufft/ops.py +++ b/src/jax_finufft/ops.py @@ -2,14 +2,13 @@ from functools import partial, reduce -import numpy as np import jax -from jax import jit -from jax import numpy as jnp -from jax.interpreters import ad, batching, xla, mlir +import numpy as np +from jax import jit, numpy as jnp from jax.extend.core import Primitive +from jax.interpreters import ad, batching, mlir, xla -from jax_finufft import shapes, lowering, options +from jax_finufft import lowering, options, shapes @partial(jit, static_argnums=(0,), static_argnames=("iflag", "eps", "opts")) @@ -60,6 +59,22 @@ def nufft2(source, *points, iflag=-1, eps=1e-6, opts=None): return index.unflatten(result) +def get_frequency_array(n, modeord): + if modeord == 0: + return np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) + elif modeord == 1: + if n % 2 == 0: + pos = np.arange(0, n // 2) + neg = np.arange(-n // 2, 0) + return np.concatenate([pos, neg]) + else: + pos = np.arange(0, (n + 1) // 2) + neg = np.arange(-(n - 1) // 2, 0) + return np.concatenate([pos, neg]) + else: + raise ValueError(f"Unsupported modeord: {modeord}") + + def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts): # Type 1: # f_k = sum_j c_j * exp(iflag * i * k * x_j) @@ -75,6 +90,9 @@ def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts): source, *points, output_shape=output_shape, iflag=iflag, eps=eps, opts=opts ) + # Extract modeord from opts + modeord = opts.modeord if hasattr(opts, "modeord") else 0 + # The JVP op can be written as a single transform of the same type with output_tangents = [] ndim = len(points) @@ -105,7 +123,7 @@ def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts): n = source.shape[-ndim + dim] if output_shape is None else output_shape[dim] shape = np.ones(ndim, dtype=int) shape[dim] = -1 - k = np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1)) + k = get_frequency_array(n, modeord) k = k.reshape(shape) factor = 1j * iflag * k dx = dx[:, None, :] From a7bf79459dd3b833f0575a25af0891fb85afe294 Mon Sep 17 00:00:00 2001 From: Tom Hilder Date: Wed, 20 Aug 2025 17:58:07 +1000 Subject: [PATCH 2/2] undo changes to inputs from my formatter --- src/jax_finufft/ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py index 22155bc..65850bc 100644 --- a/src/jax_finufft/ops.py +++ b/src/jax_finufft/ops.py @@ -2,13 +2,14 @@ from functools import partial, reduce -import jax import numpy as np -from jax import jit, numpy as jnp +import jax +from jax import jit +from jax import numpy as jnp +from jax.interpreters import ad, batching, xla, mlir from jax.extend.core import Primitive -from jax.interpreters import ad, batching, mlir, xla -from jax_finufft import lowering, options, shapes +from jax_finufft import shapes, lowering, options @partial(jit, static_argnums=(0,), static_argnames=("iflag", "eps", "opts"))