
Implementing the Pool-adjacent-violators algorithm (PAVA) in jax #14406

Answered by jakevdp
bantin asked this question in Q&A

This does not look like an algorithm that is well-suited to JAX's computation model, both because of the dynamic array shapes (the length of `x[j:i+2]` depends on the runtime values of `j` and `i`) and because of the nested control flow, which will not perform well on accelerators.

That said, you can probably make it work by employing some tricks to express your dynamically-shaped computations in terms of statically-shaped arrays.
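As one illustration of that idea, here is a minimal sketch of the textbook stack-based PAVA for (optionally weighted) non-decreasing isotonic regression. It is an assumption-laden example, not the gamma-discounted variant from the question, and the function name `pava` and its signature are hypothetical. The stack of blocks lives in fixed-size buffers of length `n`, with a traced counter `k` standing in for the length of a Python list, so `jax.jit` can compile it. It still nests a `lax.while_loop` inside a `lax.fori_loop`, so the performance caveat about nested control flow on accelerators applies.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def pava(y, w=None):
    """Non-decreasing isotonic regression with a stack-based PAVA (sketch).

    Every buffer has the static shape (n,); the number of live blocks `k`
    is a traced integer instead of the length of a Python list.
    """
    n = y.shape[0]
    if w is None:  # resolved at trace time: None carries no array leaves
        w = jnp.ones_like(y)

    def violates(state):
        means, weights, sizes, k = state
        # Keep pooling while the two topmost blocks are out of order.
        return (k >= 2) & (means[k - 2] > means[k - 1])

    def merge_top(state):
        means, weights, sizes, k = state
        # Replace the top two blocks by their weighted mean.
        w_new = weights[k - 2] + weights[k - 1]
        m_new = (weights[k - 2] * means[k - 2] + weights[k - 1] * means[k - 1]) / w_new
        means = means.at[k - 2].set(m_new)
        weights = weights.at[k - 2].set(w_new)
        sizes = sizes.at[k - 2].add(sizes[k - 1])
        return means, weights, sizes, k - 1

    def push(i, state):
        means, weights, sizes, k = state
        # Push y[i] as a new one-element block, then pool adjacent violators.
        means = means.at[k].set(y[i])
        weights = weights.at[k].set(w[i])
        sizes = sizes.at[k].set(1)
        return lax.while_loop(violates, merge_top, (means, weights, sizes, k + 1))

    init = (jnp.zeros_like(y), jnp.zeros_like(y),
            jnp.zeros(n, dtype=jnp.int32), jnp.int32(0))
    means, _, sizes, k = lax.fori_loop(0, n, push, init)

    # Expand block means back to length n. Stale entries past k are masked
    # out, so ends[b] is one past the last position covered by block b.
    ends = jnp.cumsum(jnp.where(jnp.arange(n) < k, sizes, 0))
    block = jnp.searchsorted(ends, jnp.arange(n), side="right")
    return means[block]
```

For example, `pava(jnp.array([1.0, 3.0, 2.0, 4.0]))` pools the violating pair `3.0, 2.0` into its mean and returns `[1.0, 2.5, 2.5, 4.0]`. Keeping a counter plus fixed-size buffers is the same "static shape, dynamic contents" trick as the masking below, applied to a stack instead of a slice.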

For example, you could replace this:

```python
data_block = x[j:i+2]
numerator = np.sum(data_block * gamma**(np.arange(len(data_block))))
```

with something like this:

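```python
import jax.numpy as jnp

ind = jnp.arange(len(x))
# Zero out everything outside the window x[j:i+2] instead of slicing it.
data_block = jnp.where((ind >= j) & (ind < i + 2), x, 0)
numerator = jnp.sum(data_block * gamma ** (ind - j))
```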

That should compute the same value (unless yo…
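To check the replacement, here is a small sketch comparing the two forms; the function names `numerator_slice` and `numerator_masked` and the test values are hypothetical. Under `jax.jit` the slicing version raises an `IndexError`, because its slice bounds are traced values, while the masked version compiles for any `j` and `i`. The exponent is taken relative to the window start `j`, matching `np.arange(len(data_block))` in the original. One addition beyond the snippet, labeled here as an assumption: the exponent is also clamped to 0 outside the window, so a small `gamma` raised to a negative power cannot overflow to `inf` and turn `0 * inf` into `nan`.

```python
import jax
import jax.numpy as jnp


def numerator_slice(x, j, i, gamma):
    # NumPy-style: the slice length i + 2 - j depends on the values of j and i.
    data_block = x[j:i + 2]
    return jnp.sum(data_block * gamma ** jnp.arange(len(data_block)))


@jax.jit
def numerator_masked(x, j, i, gamma):
    # Static-shape version: every array keeps length len(x) for any j and i.
    ind = jnp.arange(x.shape[0])
    mask = (ind >= j) & (ind < i + 2)
    data_block = jnp.where(mask, x, 0.0)
    # Clamp exponents outside the window to 0: otherwise gamma ** (ind - j)
    # can overflow to inf for a small gamma, and 0 * inf gives nan.
    powers = gamma ** jnp.where(mask, ind - j, 0)
    return jnp.sum(data_block * powers)


x = jnp.arange(1.0, 7.0)
print(numerator_slice(x, 1, 3, 0.5))   # eager call: j and i are concrete ints
print(numerator_masked(x, 1, 3, 0.5))  # jitted call: j and i are traced

try:
    jax.jit(numerator_slice)(x, 1, 3, 0.5)
except IndexError as err:  # JAX requires static slice bounds under jit
    print("slicing failed under jit:", type(err).__name__)
```

Both printed values should agree (here the window is `x[1:5] = [2, 3, 4, 5]` with weights `1, 0.5, 0.25, 0.125`). The cost of the masked form is that it always touches all `len(x)` elements, regardless of the window size.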

Answer selected by bantin