Autodiff likelihoods from lsbi #43

Open

yallup opened this issue Jul 23, 2024 · 2 comments

Comments
@yallup
Collaborator

yallup commented Jul 23, 2024

It would be useful for testing numerical inference algorithms to have differentiable likelihoods in lsbi. In principle I think the whole package could swap to JAX, but things like rng handling are quite different and would require some excavation; links to #41.
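For illustration, the core rng difference is the following (a minimal sketch, not lsbi code): numpy draws from a stateful generator, whereas JAX threads explicit PRNG keys through every call.

import numpy as np
from jax import random

x_np = np.random.randn(3)            # numpy: stateful RNG, no key to manage

key = random.PRNGKey(0)              # JAX: explicit, splittable key
key, subkey = random.split(key)
x_jax = random.normal(subkey, (3,))  # every draw consumes a fresh subkey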

The basic thing one needs is the ability to furnish the distributions with a JAX log_prob function. The most useful would be the likelihood, which can be done fairly simply as below.

from lsbi.model import LinearModel
from jax.scipy.stats import multivariate_normal
import jax.numpy as jnp
import numpy as np

d = 100  # data dimension
t = 5    # parameter dimension
C = np.eye(d) * 50  # data (likelihood) covariance
model = LinearModel(M=np.random.randn(d, t), C=C)
true_theta, true_data = np.split(model.joint().rvs(), [t], axis=-1)


def log_prob(theta):
    # JAX version of model.likelihood(theta).logpdf(true_data)
    mu = model.m + jnp.einsum(
        "...ja,...a->...j", model._M, theta * jnp.ones(model.n)
    )
    return multivariate_normal.logpdf(true_data, mean=mu, cov=model._C)


from jax import random
from jax import vmap, value_and_grad

rng = random.PRNGKey(0)

theta_samples = random.normal(rng, (100, t))
np_log_prob = model.likelihood(theta_samples).logpdf(true_data)  # lsbi (numpy) reference
jax_log_prob = log_prob(theta_samples)  # JAX version

# per-sample log-likelihoods and gradients via autodiff
value, grad = vmap(value_and_grad(log_prob))(theta_samples)

print((np_log_prob - jax_log_prob).mean())
print((np_log_prob - value).mean())
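If everything is consistent, both printed differences should be ≈ 0 (up to float32 precision), i.e. the JAX log_prob reproduces the lsbi likelihood and its vmapped value_and_grad version.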

And for basic mixtures:

from lsbi.model import MixtureModel, LinearModel
from jax.scipy.stats import multivariate_normal
from jax.scipy.special import logsumexp
import jax.numpy as jnp
import numpy as np

d = 100  # data dimension
t = 5    # parameter dimension
k = 3    # number of mixture components
C = np.eye(d) * 50  # data (likelihood) covariance
# model = LinearModel(M=np.random.randn(d, t))
model = MixtureModel(M=np.random.randn(k, d, t), C=C)
true_theta, true_data = np.split(model.joint().rvs(), [t], axis=-1)


def log_prob(theta):
    # per-component means, broadcast over the k mixture components
    mu = model.m + jnp.einsum(
        "...ja,...a->...j",
        model._M,
        jnp.expand_dims(theta, -2) * jnp.ones(model.n),
    )
    # NB: this normalisation is only correct for uniform weights (see note below)
    mixture_weights = logsumexp(model.logw * jnp.ones(model.k))
    return (
        logsumexp(
            multivariate_normal.logpdf(true_data, mean=mu, cov=model._C),
            axis=-1,
        )
        - mixture_weights
    )


from jax import random
from jax import vmap, value_and_grad

rng = random.PRNGKey(0)

theta_samples = random.normal(rng, (100, t))
np_log_prob = model.likelihood(theta_samples).logpdf(true_data)
jax_log_prob = log_prob(theta_samples)

value, grad = vmap(value_and_grad(log_prob))(theta_samples)


print((np_log_prob - jax_log_prob).mean())
print((np_log_prob - value).mean())

Not sure if this can be elegantly integrated, but I will put this here for now as it is potentially useful for other projects.

NB: the weighting for mixtures with non-trivial weights is wrong here; to be fixed later.
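For reference, a possible corrected version might look like the following (untested sketch, reusing the definitions from the mixture snippet above and assuming model.logw broadcasts to a length-k vector of unnormalised log weights): weight each component's log-density before the logsumexp and normalise by the total weight.

def log_prob_weighted(theta):
    mu = model.m + jnp.einsum(
        "...ja,...a->...j",
        model._M,
        jnp.expand_dims(theta, -2) * jnp.ones(model.n),
    )
    logw = model.logw * jnp.ones(model.k)
    # log p(D|theta) = logsumexp_k(log w_k + log N_k) - logsumexp_k(log w_k)
    logpdfs = multivariate_normal.logpdf(true_data, mean=mu, cov=model._C)
    return logsumexp(logw + logpdfs, axis=-1) - logsumexp(logw, axis=-1)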

@williamjameshandley
Contributor

The other option here is to have analytic gradients (and Hessians); I don't know whether this would be less flexible, or faster or slower?
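For context, the analytic gradient of the linear-Gaussian log-likelihood above is available in closed form, grad_theta log L = M^T C^-1 (D - m - M theta). A rough numpy sketch (reusing model and true_data from the first example; illustrative only, not lsbi API):

import numpy as np

def analytic_loglike_grad(theta):
    # d/dtheta log N(true_data; m + M theta, C) = M^T C^{-1} (true_data - m - M theta)
    residual = true_data - model.m - model._M @ theta
    return model._M.T @ np.linalg.solve(model._C, residual)

This could be benchmarked against the jitted autodiff gradient to settle the faster-or-slower question.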

@yallup
Collaborator Author

yallup commented Jul 24, 2024

Good point! That is probably better and fits the ethos more. I will say that this is not expensive and relatively easy to modify, so until we know what we actually want to optimize/sample, this is probably sufficient.

Below is an example of fitting a model matrix from a single joint observation.

Maximum Likelihood: [model_opt plot: fitted surrogate posterior vs true posterior]

Maximum Evidence: [model_opt plot: fitted surrogate posterior vs true posterior]
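The two figures correspond to running the script below with either the likelihood objective or the commented-out evidence objective in log_prob.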

from lsbi.model import MixtureModel, LinearModel
from jax.scipy.stats import multivariate_normal
import jax.numpy as jnp
import numpy as np
import anesthetic as ns
import matplotlib.pyplot as plt

d = 100  # data dimension
t = 5    # parameter dimension
k = 3    # number of mixture components (only needed for the commented MixtureModel)
C = np.eye(d) * 50  # data (likelihood) covariance
model = LinearModel(M=np.random.randn(d, t), C=C)
# model = MixtureModel(M=np.random.randn(k, d, t), C=C)
true_theta, true_data = np.split(model.joint().rvs(), [t], axis=-1)


def log_prob(theta_m):
    # evidence objective (commented out): marginal likelihood of the data
    # under the candidate model matrix theta_m
    # mu = model.m + jnp.einsum(
    #     "...ja,...a->...j", theta_m, true_theta * jnp.ones(model.n)
    # )
    # Σ = model._C + jnp.einsum(
    #             "...ja,...ab,...kb->...jk", theta_m, model._Σ, theta_m
    #         )
    # return multivariate_normal.logpdf(true_data, mean=mu, cov=Σ)

    # likelihood objective: negative log-likelihood, so it can be minimised
    mu = model.m + jnp.einsum(
        "...ja,...a->...j", theta_m, true_theta * jnp.ones(model.n)
    )
    return -multivariate_normal.logpdf(true_data, mean=mu, cov=model._C)

from jax import random
from jax import vmap, value_and_grad, jit
import optax
from jaxopt import LBFGS

rng = random.PRNGKey(0)

# sanity check: evaluate the loss at a random model matrix
theta_m_samples = random.normal(rng, (d, t))
# np_log_prob = model.likelihood(theta_m_samples).logpdf(true_data)
jax_log_prob = log_prob(theta_m_samples)

# value, grad = vmap(value_and_grad(log_prob))(theta_m_samples)


theta_m = random.normal(rng, (d, t))  # initial guess for the model matrix
steps = 1000

# alternative: first-order optimisation with optax
# optimizer = optax.adam(1)
# opt_state = optimizer.init(theta_m)
solver = LBFGS(jit(log_prob), maxiter=steps)
# losses = []
# for i in range(steps):
#     value, grad = jit(value_and_grad(log_prob))(theta_m)
#     updates, opt_state = optimizer.update(grad, opt_state)
#     theta_m = optax.apply_updates(theta_m, updates)
#     losses.append(value)
#     print(value)

# jaxopt's run returns an OptStep (params, state), so res[0] is the fitted M
res = solver.run(theta_m)

surrogate_model = LinearModel(M=res[0], C=C)  # keep the generating model's noise covariance

a = ns.MCMCSamples(surrogate_model.posterior(true_data).rvs(500)).plot_2d(figsize=(6, 6), label="Fitted Surrogate Posterior")
ns.MCMCSamples(model.posterior(true_data).rvs(500)).plot_2d(a, label="True Posterior")
a.iloc[-1, 0].legend(
    loc="lower center",
    bbox_to_anchor=(len(a) / 2, len(a)),
)
plt.savefig("model_opt.pdf")
