Stochastic support #97
rlouf started this conversation in New features
Replies: 1 comment
-
Note before I close that this is easily implemented in |
-
Programs with stochastic support (with control flow in them) are notoriously difficult to implement and to sample [1].
JAX cannot JIT-compile a function whose Python control flow depends on traced values (although it can apply `grad` to it), and instead requires a special construct, `jax.lax.cond` or `jax.lax.switch`. This notation is cumbersome.

Since MCX parses model expressions into a graphical model, it does seem feasible to extract the control flow and translate it into a graph structure, which is then compiled into a JAX-compatible function.
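
To make the JAX constraint concrete, here is a minimal sketch (the function names `f_python`, `f_cond`, and `f_switch` are made up for illustration and are not part of MCX). It shows that `grad` works through a plain Python `if`, that `jit` fails on the same function, and how the branch can be written with `jax.lax.cond` or `jax.lax.switch` instead:

```python
import jax


def f_python(x):
    # Plain Python branching on a value that depends on the input.
    if x > 0:
        return 2.0 * x
    return -x


def f_cond(x):
    # Same function, with the branch expressed through jax.lax.cond.
    return jax.lax.cond(x > 0, lambda v: 2.0 * v, lambda v: -v, x)


def f_switch(index, x):
    # Multi-way branching: `index` selects one of the branches at runtime.
    branches = (lambda v: v + 1.0, lambda v: v * 2.0, lambda v: -v)
    return jax.lax.switch(index, branches, x)


# Without jit, grad sees a concrete value, so the Python `if` is fine.
grad_python = jax.grad(f_python)(3.0)

# Under jit the condition is an abstract tracer and cannot be turned into a bool.
try:
    jax.jit(f_python)(3.0)
    jit_python_failed = False
except jax.errors.ConcretizationTypeError:
    jit_python_failed = True

# The lax versions compile, and can still be differentiated.
value_cond = jax.jit(f_cond)(3.0)
grad_cond = jax.jit(jax.grad(f_cond))(-1.0)
value_switch = jax.jit(f_switch)(1, 3.0)
```

Both branches of `cond` and `switch` have to return values of the same shape and dtype, which is part of what makes translating arbitrary model code non-trivial.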
We need to discuss:
This does not need to be implemented for v0.1, but the API design should be clear before releasing, as it will impact the way the graph is structured. Use this issue for discussions.
References
[1]: "Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support" https://arxiv.org/abs/1910.13324