Here are a few ideas for the core that I won't have time to implement before the first release, but which have the potential to make MCX more general and greatly simplify the bijectors API.
First, a shortcoming of many PPLs is that it is impossible to condition on deterministic transformations of random variables. This is because the logpdf would need to propagate the log-determinant of the Jacobian of the inverse transformation to conserve probability volume. This seems to be a job for Jaxprs: the idea would be to have the core compile the graph into a function that JAX can manipulate, and to derive a "logpdf" Jaxpr from that function.
Then, if this works, we would only need to implement the "forward" part of bijectors; the logpdf Jaxpr would automatically take care of conserving the volume. Writing a Jaxpr transformation that inverts the forward map is, if not easy, at least possible.
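To make the volume-conservation idea concrete, here is a minimal sketch in plain JAX (not MCX code; `transformed_logpdf`, `base_logpdf` and `forward` are made-up names) of how a logpdf could be corrected when only the forward map of a bijector is written by hand, with the log-determinant of the Jacobian obtained through autodiff:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

def transformed_logpdf(base_logpdf, forward, x):
    """Log-density contribution of y = forward(x), evaluated at the base value x.

    Change of variables: log p_y(forward(x)) = log p_x(x) - log|det J_forward(x)|
    """
    jacobian = jax.jacfwd(forward)(x)  # Jacobian of the forward map at x
    _, logdet = jnp.linalg.slogdet(jnp.atleast_2d(jacobian))
    return base_logpdf(x) - logdet

# Example: x ~ Normal(0, 1) pushed through exp, i.e. a log-normal density for y = exp(x)
x = jnp.array([0.5])
logp = transformed_logpdf(lambda v: stats.norm.logpdf(v).sum(), jnp.exp, x)
```

In a compiled model the same correction term would be inserted by manipulating the Jaxpr rather than by calling a helper explicitly.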
Graph --> JAX-ready logpdf --> logpdf Jaxpr
Graph --> Joint distribution forward sampler
Graph --> Predictive distribution sampler
Sampling and predictive sampling are simple enough that they can be left as they are.
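As a rough illustration of the first target above, JAX can already expose the Jaxpr of any traceable logpdf; `joint_logpdf` below is a hand-written stand-in for what the compiler would emit from the graph (hypothetical, not MCX's actual output):

```python
import jax
import jax.scipy.stats as stats

# Hand-written stand-in for a compiled joint logpdf of the model
#   mu ~ Normal(0, 1);  x ~ Normal(mu, 1)
def joint_logpdf(mu, x):
    return stats.norm.logpdf(mu) + stats.norm.logpdf(x, loc=mu)

# The Jaxpr is the intermediate representation that could then be manipulated,
# e.g. to insert log-det-Jacobian terms for transformed variables.
print(jax.make_jaxpr(joint_logpdf)(0.0, 1.0))
```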
As a result we would have a two-layer core:
A "conceptual" graph that makes it possible to reason about distributions directly: identifying conjugacy relationships that can be collapsed, the transformations to apply to random variables with constrained support, the samplers best suited to each variable, etc.
A computational layer that relies on custom JAX primitives to compute the logprob of transformed random variables and simplifies the computation when possible.
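A heavily simplified sketch of how the two layers could fit together (the data structures are made up for illustration; nothing here is MCX's actual API): a symbolic graph of random variables on top, and a compilation step that emits a plain JAX-traceable logpdf below.

```python
import jax.numpy as jnp
import jax.scipy.stats as stats

# Conceptual layer: nodes only record distributional information, so the graph
# can be inspected for conjugacy, constrained supports, etc.
# (toy format: name -> (distribution, parameters, parents))
graph = {
    "mu": ("Normal", {"loc": 0.0, "scale": 1.0}, []),
    "x": ("Normal", {"scale": 1.0}, ["mu"]),  # x | mu ~ Normal(mu, 1)
}

# Computational layer: compile the graph into a function that JAX can trace,
# differentiate and jit (only the Normal case is handled in this toy example).
def compile_logpdf(graph):
    def logpdf(values):
        total = 0.0
        for name, (dist, params, parents) in graph.items():
            loc = values[parents[0]] if parents else params.get("loc", 0.0)
            total += stats.norm.logpdf(values[name], loc=loc, scale=params["scale"]).sum()
        return total
    return logpdf

logpdf = compile_logpdf(graph)
logpdf({"mu": jnp.array(0.1), "x": jnp.array(0.5)})
```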