Hi - thanks for the question! This is working as expected: outside of `jit`, everything is executed eagerly, and there is no cache of traced functions. This means that when you call a vmapped function that is not JIT-compiled, it is retraced every time. If this is not the behavior you want, I'd suggest wrapping the vmapped function in `jit` to take advantage of the JIT cache. Does that make sense?
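A minimal sketch of the suggested fix, assuming a recent JAX version; `scaled_sin`, `loss` and the `n_traces` counter are hypothetical names for illustration. A Python side effect inside the function body only runs while JAX traces it, so counting those side effects shows how often tracing happens. Wrapped in `jax.jit`, both the vmapped and the differentiated function are traced on the first call and then served from the cache.

```python
import jax
import jax.numpy as jnp

# Hypothetical counters: how many times each Python body is traced.
n_traces = {"vmap": 0, "grad": 0}

def scaled_sin(x):
    n_traces["vmap"] += 1  # side effect fires only while JAX traces the function
    return jnp.sin(x) * 2.0

def loss(w):
    n_traces["grad"] += 1  # same idea for the function being differentiated
    return jnp.sum(w ** 2)

# Build the jitted, transformed functions once and reuse the wrapped objects.
batched = jax.jit(jax.vmap(scaled_sin))
grad_fn = jax.jit(jax.grad(loss))

xs = jnp.arange(4.0)
for _ in range(3):
    batched(xs)  # traced on the first call, then served from the jit cache
    grad_fn(xs)

print(n_traces)  # expected: {'vmap': 1, 'grad': 1}
```

The cache belongs to the jitted object, so it is worth creating `jax.jit(jax.vmap(f))` once and reusing it; rebuilding it inside a loop produces a new function each time, which would likely be traced again.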
-
Hi, I would like to ask about the basic concepts of JAX. I’ve noticed that without `jit`, both `vmap` and `grad` seem to re-trace the function each time they are called, and I don’t fully understand why this happens. From what I understand, JAX traces the batched array to generate the computation graph. However, for an identical input shape passed in subsequent calls, it should be able to reuse the computation graph rather than re-tracing it every time.
I believe the key difference between using `jit` and not using it is whether the computation graph is compiled. My assumption was that without `jit`, the function would pass through an uncompiled version of the graph, while with `jit`, it would pass through the compiled machine code. However, it seems like the function is traced each time, even without `jit`.
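A sketch of how the `jit` cache behaves, under the same assumptions (`f` and `trace_log` are illustrative names). `jit` keys its cache roughly on the abstract signature of the arguments (shapes and dtypes, plus static arguments and pytree structure), so a repeated shape reuses the compiled executable while a new shape triggers a fresh trace. `jax.make_jaxpr` prints the intermediate program (a jaxpr) that tracing produces, which is the closest thing to the "computation graph" mentioned above. Without `jit`, transformations such as `vmap` apply their rules primitive by primitive as the Python code runs, so there is no stored graph to reuse on the next call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical log: the input shape seen by each trace of `f`

def f(x):
    trace_log.append(x.shape)  # only runs while jit is tracing
    return jnp.sin(x) * 2.0

f_jit = jax.jit(f)

f_jit(jnp.ones(3))  # first (3,) float32 input: trace, lower, compile
f_jit(jnp.ones(3))  # same shape/dtype: cached executable reused, no trace
f_jit(jnp.ones(5))  # new shape (5,): trace and compile again

print(trace_log)  # expected: [(3,), (5,)]

# Inspect the traced program (jaxpr) for the same computation.
jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(jnp.ones(3))
print(jaxpr)
```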