Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uncertainty aware models #66

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
39 changes: 29 additions & 10 deletions modelforge/potential/bayesian_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@
import pyro
from pyro.nn.module import to_pyro_module_

import functools

def rsetattr(obj, attr, val):
pre, _, post = attr.rpartition('.')
return setattr(rgetattr(obj, pre) if pre else obj, post, val)

# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427

def rgetattr(obj, attr, *args):
def _getattr(obj, attr):
return getattr(obj, attr, *args)
return functools.reduce(_getattr, [obj] + attr.split('.'))

def init_log_sigma(model, value):
"""Initializes the log_sigma parameters of a model

Expand All @@ -14,17 +27,21 @@ def init_log_sigma(model, value):
The value to initialize the log_sigma parameters to

"""
log_sigma_params = {
name + "_log_sigma": pyro.nn.PyroParam(
torch.ones(param.shape) * value,
params = {
name: pyro.nn.PyroSample(
pyro.distributions.Normal(
torch.zeros(param.shape),
torch.ones(param.shape) * value,
)
)
for name, param in model.named_parameters()
}

for name, param in log_sigma_params.items():
setattr(model, name, param)
for name, param in model.named_parameters():
rsetattr(model, name, params[name])


class BayesianAutoNormalPotential(torch.nn.Module):
class BayesianAutoNormalPotential(pyro.nn.PyroModule):
"""A Bayesian model with a normal prior and likelihood.

Parameters
Expand All @@ -39,19 +56,21 @@ class BayesianAutoNormalPotential(torch.nn.Module):
provide the prior; if `y` is provided, provide the likelihood.
"""
def __init__(
self,
self, base_model,
*args, **kwargs,
):
super().__init__()
to_pyro_module_(base_model)
self.base_model = base_model
log_sigma = kwargs.pop("log_sigma", 0.0)
init_log_sigma(self, log_sigma)
init_log_sigma(self.base_model, log_sigma)

def model(self, *args, **kwargs):
def forward(self, *args, **kwargs):
"""The model function. If no `y` argument is provided,
provide the prior; if `y` is provided, provide the likelihood.
"""
y = kwargs.pop("y", None)
y_hat = self(*args, **kwargs)
y_hat = self.base_model(*args, **kwargs).E
pyro.sample(
"obs",
pyro.distributions.Delta(y_hat),
Expand Down
5 changes: 5 additions & 0 deletions modelforge/tests/ase.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1 = -1313.4668615546
6 = -99366.70745535441
7 = -143309.9379722722
8 = -197082.0671774158
9 = -261811.54555874597
26 changes: 26 additions & 0 deletions modelforge/tests/test_bayesian_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
import pyro
from modelforge.potential import SchNet
from modelforge.potential.bayesian_models import BayesianAutoNormalPotential
from .helper_functions import SIMPLIFIED_INPUT_DATA

@pytest.mark.parametrize("input_data", SIMPLIFIED_INPUT_DATA)
def test_bayesian_model(input_data):
# initialize a vanilla SchNet model
schnet = SchNet()

# make a Bayesian model from the SchNet
schnet = BayesianAutoNormalPotential(schnet, log_sigma=1e-2).forward
guide = pyro.infer.autoguide.AutoDiagonalNormal(schnet)
assert guide is not None

# run SVI using the Bayesian model
svi = pyro.infer.SVI(
model=schnet,
guide=guide,
optim=pyro.optim.Adam({"lr": 1e-3}),
loss=pyro.infer.Trace_ELBO(),
)

# calculate VI loss
svi.step(input_data, y=0.0)
Loading