Pipelined calculation involving scan and pure_callback
#25232
Unanswered
mfschubert asked this question in Q&A
Replies: 0 comments
This question concerns pipelined calculations involving jax.lax.scan and jax.pure_callback. I have an expensive calculation with two parts, one of which runs through pure_callback.
Both calculations are used in a scan operation, as shown in the example below.
I am hoping to pipeline the calculation to speed things up, as follows:
I would expect the compute time to be cut in half here, since slow_callback_fn(key) and dummy_jax_fn(carry) take an equal amount of time and can run independently. However, this doesn't seem to be the case in practice. Is this expected? Is there some other way I can force these two calculations to run in parallel?