-
Hi JAX folks. I'd like to implement a variant of the pool-adjacent-violators algorithm (PAVA) in JAX. For context, PAVA solves the following simple optimization problem, which is typically called isotonic regression:

    min_x |x - y|_2^2   s.t.   x_1 >= x_2 >= ... >= x_N

Here x and y are both vectors. In my application there is an additional parameter gamma, since we'd like to specify that x is not just decaying, but exponentially decaying:

    min_x |x - y|_2^2   s.t.   gamma * x_1 >= x_2, gamma * x_2 >= x_3, ..., gamma * x_{N-1} >= x_N

There are efficient algorithms for solving these in O(N) time, where N is the length of x and y. In my case, I'll be happy with a naive O(N^2) implementation. It's easy to code up this naive implementation in a few lines of Python:
This works by sliding an index, … However, it's not clear to me how to implement this backtracking approach in JAX. For instance, the following code won't work because we can't slice dynamically inside the …
It's not clear to me that there's a nice way of expressing this "backtrack until constraints are satisfied" idea in a JAX-friendly way. I need to be able to `jit` and `vmap` this entire computation (otherwise I'd just use Julia!). Any thoughts here?
-
This does not look like an algorithm that is well suited to JAX's computation model, both because of the dynamic array shapes and because of the nested control flow, which will not perform well on accelerators.

That said, you can probably make it work by employing some tricks to express your dynamically shaped computations in terms of statically shaped arrays. For example, you could replace this:

    data_block = x[j:i+2]
    numerator = np.sum(data_block * gamma**(np.arange(len(data_block))))

with something like this:

    ind = jnp.arange(len(x))
    data_block = jnp.where((ind >= j) & (ind < i + 2), x, 0)
    numerator = jnp.sum(data_block * gamma ** (ind - j))

That should compute the same value (unless you run into overflow/underflow with …)
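To make the masking idea concrete, here is a minimal sketch of it wrapped in a jitted function. The name `block_numerator` is hypothetical, and clamping the exponent outside the block is my own addition, meant to keep `gamma ** exponent` finite for masked-out entries. Treat it as an illustration under those assumptions rather than a tuned implementation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def block_numerator(x, j, i, gamma):
    """Masked equivalent of np.sum(x[j:i+2] * gamma ** np.arange(i + 2 - j))."""
    ind = jnp.arange(x.shape[0])           # static shape: always len(x)
    in_block = (ind >= j) & (ind < i + 2)  # True only for positions j .. i+1
    # Exponent counts from 0 at position j; outside the block it is set to 0
    # so that gamma ** exponent cannot overflow or underflow there.
    exponent = jnp.where(in_block, ind - j, 0)
    return jnp.sum(jnp.where(in_block, x * gamma ** exponent, 0.0))


x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
value = block_numerator(x, 1, 2, 0.5)  # same terms as x[1:4] * 0.5 ** [0, 1, 2]
```

Because `j` and `i` only enter through comparisons, they can be traced values, so the function can be called from inside `lax.while_loop` or under `jax.vmap` without triggering a dynamic-slice error.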
-
Did you manage to do it in the end? I am interested in a solution.
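For what it's worth, here is one possible sketch. It is not taken from this thread, and I have not benchmarked it. It assumes gamma > 0, so that substituting x_i = gamma^(i-1) * z_i turns the problem into a weighted regression with z_1 >= z_2 >= ... >= z_N, where weighted PAVA applies. The "backtrack until constraints are satisfied" step becomes a `lax.while_loop` that merges blocks on a fixed-size stack. The function name `pava_decay` and the stack layout are my own choices.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pava_decay(y, gamma):
    """Sketch: min ||x - y||^2  s.t.  gamma * x[i] >= x[i+1], assuming gamma > 0."""
    n = y.shape[0]
    k = jnp.arange(n)
    w = gamma ** (2 * k)   # weight of z[i] after substituting x[i] = gamma**i * z[i]
    wt = gamma ** k * y    # weight times target, i.e. w[i] * (y[i] / gamma**i)

    def push_and_merge(i, state):
        num, den, start, top = state
        # Push element i as a new single-element block on top of the stack.
        num = num.at[top].set(wt[i])
        den = den.at[top].set(w[i])
        start = start.at[top].set(i)
        top = top + 1

        def violates(s):
            num, den, _, top = s
            a = jnp.maximum(top - 2, 0)
            b = jnp.maximum(top - 1, 0)
            # Block values are num / den; den > 0, so cross-multiplying is safe.
            # A violation means the earlier block's value is below the later one's.
            return (top >= 2) & (num[a] * den[b] < num[b] * den[a])

        def merge(s):
            num, den, start, top = s
            # Fold the top block into the one beneath it and pop the stack.
            num = num.at[top - 2].add(num[top - 1])
            den = den.at[top - 2].add(den[top - 1])
            return num, den, start, top - 1

        return lax.while_loop(violates, merge, (num, den, start, top))

    init = (jnp.zeros_like(wt), jnp.zeros_like(w),
            jnp.zeros(n, dtype=jnp.int32), jnp.int32(0))
    num, den, start, top = lax.fori_loop(0, n, push_and_merge, init)

    # Map each position to its block: count the live blocks starting at or before it.
    valid = k < top
    block_id = jnp.sum(valid[None, :] & (start[None, :] <= k[:, None]), axis=1) - 1
    return (num[block_id] / den[block_id]) * gamma ** k


x_single = jax.jit(pava_decay)(jnp.array([4.0, 4.0]), 0.5)
x_batch = jax.vmap(pava_decay, in_axes=(0, None))(
    jnp.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]), 1.0)
```

All arrays keep length N throughout, which is what lets this trace under `jit` and `vmap`. Two caveats: the final block lookup builds an N x N boolean array, and `gamma ** k` can underflow for long inputs with small gamma, so a rescaled formulation may be needed in that regime.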