Core & custom Jaxprs #62

Open
rlouf opened this issue Dec 5, 2020 · 0 comments
Labels: core-compiler (Non user-facing improvements), discussion

rlouf (Owner) commented Dec 5, 2020

Here are a few ideas for the core that I won't have time to implement before the first release, but that could make MCX more general and greatly simplify the bijectors API.

First, a shortcoming of many PPLs (probabilistic programming languages) is that they cannot condition on deterministic transformations of random variables. This is because the logpdf function would need to invert the transformation and add the log absolute determinant of the Jacobian of the inverse transformation to conserve volume. This seems to be a job for Jaxprs. The idea would be to have the core compile the graph into a function that JAX can trace, and to create a "logpdf" Jaxpr transformation that is applied to this function.
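To make the volume correction concrete: if x has density p_X and y = f(x) for an invertible f, then log p_Y(y) = log p_X(f^-1(y)) + log|det J_{f^-1}(y)|. The sketch below is a hypothetical scalar example (x ~ Normal(0, 1) and y = exp(x), so y is log-normal), not MCX code. It obtains the Jacobian term with `jax.grad` and shows that the resulting function can be traced into a Jaxpr with `jax.make_jaxpr`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical example: x ~ Normal(0, 1) and we observe y = exp(x).
def inverse(y):
    # f^-1(y) = log(y)
    return jnp.log(y)

def base_logpdf(x):
    return norm.logpdf(x, 0.0, 1.0)

def transformed_logpdf(y):
    # Change of variables: log p_Y(y) = log p_X(f^-1(y)) + log|d f^-1(y) / dy|.
    # The derivative of the inverse comes from autodiff (scalar case only).
    log_det = jnp.log(jnp.abs(jax.grad(inverse)(y)))
    return base_logpdf(inverse(y)) + log_det

# Conditioning on y now works, and the whole function traces into a Jaxpr.
print(transformed_logpdf(2.0))
print(jax.make_jaxpr(transformed_logpdf)(2.0))
```

Here the result matches the closed-form log-normal log-density, which is a quick way to check that the Jacobian term has the right sign.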

Then, if this works, we would only need to implement the "forward" part of bijectors: the logpdf Jaxpr would automatically take care of conserving volume. Writing a Jaxpr that inverts the transformation is possible, if not easy.
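A minimal sketch of such an inverter, assuming the forward bijector is a scalar chain of the unary `lax` primitives `exp`, `log` and `tanh`. The `INVERSES` table and the `invert` helper are hypothetical, not MCX API. The helper traces the forward function with `jax.make_jaxpr`, walks its equations in reverse, and accumulates the log-determinant of the inverse along the way, so only `forward` is written by hand.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.stats import norm

# Hypothetical table: forward primitive name -> (inverse, log|d inverse(y) / dy|).
INVERSES = {
    "exp": (lax.log, lambda y: -jnp.log(y)),             # x = log(y), dx/dy = 1/y
    "log": (lax.exp, lambda y: y),                       # x = exp(y), dx/dy = exp(y)
    "tanh": (lax.atanh, lambda y: -jnp.log1p(-y ** 2)),  # dx/dy = 1 / (1 - y^2)
}

def invert(forward, example_x):
    """Build (inverse, log-det) from the Jaxpr of a scalar forward bijector.

    Assumes `forward` is a linear chain of unary primitives listed in INVERSES.
    """
    eqns = jax.make_jaxpr(forward)(example_x).jaxpr.eqns

    def inverse_and_log_det(y):
        log_det = 0.0
        # Walk the equations backwards, undoing one primitive at a time.
        for eqn in reversed(eqns):
            name = eqn.primitive.name
            if name not in INVERSES or len(eqn.invars) != 1:
                raise NotImplementedError(f"cannot invert primitive {name!r}")
            inv, inv_log_det = INVERSES[name]
            log_det = log_det + inv_log_det(y)  # evaluated before inverting y
            y = inv(y)
        return y, log_det

    return inverse_and_log_det

# Only the forward part of the bijector is written by hand.
def forward(x):
    return lax.tanh(lax.exp(x))

inverse_and_log_det = invert(forward, 0.0)

def transformed_logpdf(y):
    # Standard normal base distribution pushed through `forward`.
    x, log_det = inverse_and_log_det(y)
    return norm.logpdf(x) + log_det
```

A real implementation would also need primitives with several inputs (for example `mul` by a constant) and full Jacobian determinants for vector-valued transformations.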

Graph --> JAX-ready logpdf --> logpdf Jaxpr
Graph --> Joint distribution forward sampler
Graph --> Predictive distribution sampler

Forward sampling and predictive sampling are simple enough that they can be left as they are.
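The three pipelines can be sketched on a toy graph. In this hypothetical representation (not MCX's actual graph), each node is a Normal random variable whose parameters are computed from its parents. The same graph is compiled into a traceable logpdf, whose Jaxpr is obtained with `jax.make_jaxpr`, and into a forward sampler that acts as a predictive sampler when some values are held fixed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy graph: an ordered list of (name, fn) pairs, where fn maps the
# values computed so far to the (loc, scale) of a Normal random variable.
GRAPH = [
    ("mu", lambda values: (0.0, 1.0)),
    ("y", lambda values: (values["mu"], 0.5)),
]

def compile_logpdf(graph):
    """Graph --> JAX-ready logpdf (a plain function JAX can trace)."""
    def logpdf(values):
        total = 0.0
        for name, fn in graph:
            loc, scale = fn(values)
            total = total + norm.logpdf(values[name], loc, scale)
        return total
    return logpdf

def compile_sampler(graph):
    """Graph --> forward sampler; passing `fixed` values gives a predictive sampler."""
    def sample(rng_key, fixed=None):
        fixed = fixed or {}
        values = {}
        for name, fn in graph:
            rng_key, subkey = jax.random.split(rng_key)
            if name in fixed:
                values[name] = fixed[name]
            else:
                loc, scale = fn(values)
                values[name] = loc + scale * jax.random.normal(subkey)
        return values
    return sample

logpdf = compile_logpdf(GRAPH)
sample = compile_sampler(GRAPH)

# JAX-ready logpdf --> logpdf Jaxpr, which a custom transformation could rewrite.
print(jax.make_jaxpr(logpdf)({"mu": 0.1, "y": 0.3}))
print(sample(jax.random.PRNGKey(0)))               # joint forward sample
print(sample(jax.random.PRNGKey(1), {"mu": 2.0}))  # predictive sample given mu
```

Because `logpdf` is an ordinary JAX function, a custom transformation could operate on its Jaxpr rather than on the graph itself.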

As a result, we would have a two-layer core:

  • A "conceptual" graph that lets us reason about distributions directly, and thus identify conjugacy relationships that can be collapsed, transformations to apply to random variables with constrained support, the samplers best suited to each variable, etc.
  • A computational layer that relies on custom JAX primitives to compute the logprob of transformed random variables and simplifies the computation when possible.
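As an illustration of the first layer, here is a minimal sketch of one graph rewrite the conceptual layer could perform: collapsing a Normal prior on the location of an observed Normal with known scale into its closed-form posterior. The `RandomVariable` node type and the `collapse_normal_normal` pass are hypothetical and only handle this single conjugate pair.

```python
from dataclasses import dataclass
from typing import Any, Optional

import jax.numpy as jnp

@dataclass
class RandomVariable:
    # Hypothetical node of the "conceptual" graph.
    name: str
    dist: str
    params: dict                    # floats, or the name of a parent node (str)
    observed: Optional[Any] = None  # data, if the variable is conditioned on

def collapse_normal_normal(graph):
    """Collapse a Normal(m0, s0) prior feeding an observed Normal(mu, s), s known.

    Returns a new graph where the prior is replaced by its closed-form
    posterior and the absorbed likelihood node is removed.
    """
    graph = dict(graph)
    for child in list(graph.values()):
        parent_name = child.params.get("loc")
        if (child.dist != "Normal" or child.observed is None
                or not isinstance(parent_name, str)
                or not isinstance(child.params.get("scale"), float)):
            continue
        parent = graph.get(parent_name)
        if (parent is None or parent.dist != "Normal"
                or not all(isinstance(v, float) for v in parent.params.values())):
            continue
        y = jnp.asarray(child.observed)
        prior_prec = 1.0 / parent.params["scale"] ** 2
        lik_prec = 1.0 / child.params["scale"] ** 2
        # Precisions add; the posterior mean is the precision-weighted average.
        post_prec = prior_prec + y.size * lik_prec
        post_loc = (prior_prec * parent.params["loc"] + lik_prec * y.sum()) / post_prec
        graph[parent_name] = RandomVariable(
            parent_name, "Normal",
            {"loc": float(post_loc), "scale": float(post_prec ** -0.5)},
        )
        del graph[child.name]
    return graph
```

For example, a prior mu ~ Normal(0, 1) with observations [1, 2, 3] from Normal(mu, 1) collapses to mu ~ Normal(1.5, 0.5), and no sampler is needed for mu anymore.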

rlouf added the core-compiler and discussion labels Dec 5, 2020
rlouf changed the title from "Core" to "Core & custom Jaxprs" Dec 5, 2020