From 1320fda48022c18b575ff083bbe1dc3d1e483999 Mon Sep 17 00:00:00 2001 From: Yuanqing Wang Date: Mon, 5 Feb 2024 23:06:28 -0500 Subject: [PATCH 1/9] initialize everything with log sigma --- modelforge/potential/bayesian_models.py | 27 +++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 modelforge/potential/bayesian_models.py diff --git a/modelforge/potential/bayesian_models.py b/modelforge/potential/bayesian_models.py new file mode 100644 index 00000000..3b91925b --- /dev/null +++ b/modelforge/potential/bayesian_models.py @@ -0,0 +1,27 @@ +import torch +import pyro +from pyro.nn.module import to_pyro_module_ + +def init_log_sigma(model, value): + """Initializes the log_sigma parameters of a model + + Parameters + ---------- + model : torch.nn.Module + The model to initialize + + value : float + The value to initialize the log_sigma parameters to + + """ + log_sigma_params = { + name + "_log_sigma": pyro.nn.Parameter( + torch.ones(param.shape) * value, + ) + for name, param in model.named_parameters() + } + + for name, param in log_sigma_params.items(): + setattr(model, name, param) + + From 9bb1d82b9e579cced5462735592c16aed817d098 Mon Sep 17 00:00:00 2001 From: Yuanqing Wang Date: Thu, 29 Feb 2024 23:20:18 -0500 Subject: [PATCH 2/9] add SVI --- modelforge/potential/bayesian_models.py | 20 ++++++++++++- modelforge/tests/test_bayesian_model.py | 38 +++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 modelforge/tests/test_bayesian_model.py diff --git a/modelforge/potential/bayesian_models.py b/modelforge/potential/bayesian_models.py index 3b91925b..af6bd572 100644 --- a/modelforge/potential/bayesian_models.py +++ b/modelforge/potential/bayesian_models.py @@ -15,7 +15,7 @@ def init_log_sigma(model, value): """ log_sigma_params = { - name + "_log_sigma": pyro.nn.Parameter( + name + "_log_sigma": pyro.nn.PyroParam( torch.ones(param.shape) * value, ) for name, param in model.named_parameters() @@ -24,4 +24,22 @@ def init_log_sigma(model, value): for name, param in log_sigma_params.items(): setattr(model, name, param) +class BayesianAutoNormalPotential(torch.nn.Module): + def __init__( + self, + *args, **kwargs, + ): + super().__init__() + log_sigma = kwargs.pop("log_sigma", 0.0) + init_log_sigma(self, log_sigma) + + def model(self, *args, **kwargs): + y = kwargs.pop("y", None) + y_hat = self(*args, **kwargs) + pyro.sample( + "obs", + pyro.distributions.Delta(y_hat), + obs=y + ) + diff --git a/modelforge/tests/test_bayesian_model.py b/modelforge/tests/test_bayesian_model.py new file mode 100644 index 00000000..b9f49b7e --- /dev/null +++ b/modelforge/tests/test_bayesian_model.py @@ -0,0 +1,38 @@ +import pyro +from modelforge.potential import CosineCutoff, GaussianRBF +from modelforge.potential.utils import SlicedEmbedding +from modelforge.potential.schnet import SchNET +from modelforge.potential.bayesian_models import BayesianAutoNormalPotential +from .helper_functions import SIMPLIFIED_INPUT_DATA + +def test_bayesian_model(): + # initialize a vanilla SchNet model + embedding = SlicedEmbedding(8, 16, sliced_dim=0) + rbf = GaussianRBF(n_rbf=8, cutoff=5.0) + cutoff = CosineCutoff(5.0) + schnet = SchNET( + embedding=embedding, + cutoff=cutoff, + nr_interaction_blocks=8, + radial_basis=rbf, + ) + + # make a Bayesian model from the SchNet + schnet = BayesianAutoNormalPotential(schnet, log_sigma=1e-2) + guide = pyro.infer.autoguide.AutoDiagonalNormal(schnet) + assert guide is not None + + # run SVI using the Bayesian model + svi = pyro.infer.SVI( + model=schnet, + guide=guide, + optim=pyro.optim.Adam({"lr": 1e-3}), + loss=pyro.infer.Trace_ELBO(), + ) + + # calculate VI loss + svi.step(SIMPLIFIED_INPUT_DATA, y=0.0) + + + + From fee441ba84b691344f62e5fa27e5ae8e3b16694c Mon Sep 17 00:00:00 2001 From: Yuanqing Wang Date: Fri, 1 Mar 2024 09:58:41 -0500 Subject: [PATCH 3/9] add documentation --- modelforge/potential/bayesian_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modelforge/potential/bayesian_models.py b/modelforge/potential/bayesian_models.py index af6bd572..6e65cab3 100644 --- a/modelforge/potential/bayesian_models.py +++ b/modelforge/potential/bayesian_models.py @@ -34,6 +34,9 @@ def __init__( init_log_sigma(self, log_sigma) def model(self, *args, **kwargs): + """The model function. If no `y` argument is provided, + provide the prior; if `y` is provided, provide the likelihood. + """ y = kwargs.pop("y", None) y_hat = self(*args, **kwargs) pyro.sample( From 59509edc2ac3ae2cd6b557f35829d0e1eebbd2a4 Mon Sep 17 00:00:00 2001 From: Yuanqing Wang Date: Thu, 14 Mar 2024 23:02:05 -0400 Subject: [PATCH 4/9] add pyro as dependency --- devtools/conda-envs/test_env.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 5dc170fb..9c0c2ef4 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -36,3 +36,4 @@ dependencies: # pip installs - pip: - schnetpack>=2.0.0 + - pyro-ppl From 3232c56b018c40048de76790be126689b6a557d6 Mon Sep 17 00:00:00 2001 From: Yuanqing Wang Date: Fri, 15 Mar 2024 00:13:40 -0400 Subject: [PATCH 5/9] docstring --- modelforge/potential/bayesian_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/modelforge/potential/bayesian_models.py b/modelforge/potential/bayesian_models.py index 6e65cab3..2a6f285b 100644 --- a/modelforge/potential/bayesian_models.py +++ b/modelforge/potential/bayesian_models.py @@ -25,6 +25,19 @@ def init_log_sigma(model, value): setattr(model, name, param) class BayesianAutoNormalPotential(torch.nn.Module): + """A Bayesian model with a normal prior and likelihood. + + Parameters + ---------- + log_sigma : float, optional + The initial value of the log_sigma parameters. Default is 0.0. + + Methods + ------- + model + The model function. If no `y` argument is provided, + provide the prior; if `y` is provided, provide the likelihood. + """ def __init__( self, *args, **kwargs, From 9f4e891d7eea361dd316762a9c87747db72eb24c Mon Sep 17 00:00:00 2001 From: wiederm Date: Mon, 18 Mar 2024 22:34:07 +0100 Subject: [PATCH 6/9] Refactor test_bayesian_model.py: Simplify SchNet initialization --- modelforge/tests/test_bayesian_model.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/modelforge/tests/test_bayesian_model.py b/modelforge/tests/test_bayesian_model.py index b9f49b7e..1e1d0722 100644 --- a/modelforge/tests/test_bayesian_model.py +++ b/modelforge/tests/test_bayesian_model.py @@ -1,21 +1,12 @@ import pyro -from modelforge.potential import CosineCutoff, GaussianRBF -from modelforge.potential.utils import SlicedEmbedding -from modelforge.potential.schnet import SchNET +from modelforge.potential import SchNet from modelforge.potential.bayesian_models import BayesianAutoNormalPotential from .helper_functions import SIMPLIFIED_INPUT_DATA + def test_bayesian_model(): # initialize a vanilla SchNet model - embedding = SlicedEmbedding(8, 16, sliced_dim=0) - rbf = GaussianRBF(n_rbf=8, cutoff=5.0) - cutoff = CosineCutoff(5.0) - schnet = SchNET( - embedding=embedding, - cutoff=cutoff, - nr_interaction_blocks=8, - radial_basis=rbf, - ) + schnet = SchNet() # make a Bayesian model from the SchNet schnet = BayesianAutoNormalPotential(schnet, log_sigma=1e-2) @@ -32,7 +23,3 @@ def test_bayesian_model(): # calculate VI loss svi.step(SIMPLIFIED_INPUT_DATA, y=0.0) - - - - From 1bff29ee38e8da3e652c3ec71a3a4855e2639740 Mon Sep 17 00:00:00 2001 From: Yuanqing Wang Date: Thu, 21 Mar 2024 22:25:05 -0400 Subject: [PATCH 7/9] test --- modelforge/tests/test_bayesian_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modelforge/tests/test_bayesian_model.py b/modelforge/tests/test_bayesian_model.py index 1e1d0722..f6cc21de 100644 --- a/modelforge/tests/test_bayesian_model.py +++ b/modelforge/tests/test_bayesian_model.py @@ -1,10 +1,11 @@ +import pytest import pyro from modelforge.potential import SchNet from modelforge.potential.bayesian_models import BayesianAutoNormalPotential from .helper_functions import SIMPLIFIED_INPUT_DATA - -def test_bayesian_model(): +@pytest.mark.parametrize("input_data", SIMPLIFIED_INPUT_DATA) +def test_bayesian_model(input_data): # initialize a vanilla SchNet model schnet = SchNet() @@ -22,4 +23,4 @@ def test_bayesian_model(): ) # calculate VI loss - svi.step(SIMPLIFIED_INPUT_DATA, y=0.0) + svi.step(input_data, y=0.0) From 300ef35259695d35cbffc200421fb36115b3c995 Mon Sep 17 00:00:00 2001 From: Yuanqing Wang Date: Thu, 21 Mar 2024 23:17:50 -0400 Subject: [PATCH 8/9] bayesian models --- modelforge/tests/ase.toml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 modelforge/tests/ase.toml diff --git a/modelforge/tests/ase.toml b/modelforge/tests/ase.toml new file mode 100644 index 00000000..bbe625e8 --- /dev/null +++ b/modelforge/tests/ase.toml @@ -0,0 +1,5 @@ +1 = -1313.4668615546 +6 = -99366.70745535441 +7 = -143309.9379722722 +8 = -197082.0671774158 +9 = -261811.54555874597 From 51e9d52c93d13bebb5cbf83a7f7d6cc8a070356e Mon Sep 17 00:00:00 2001 From: Yuanqing Wang Date: Thu, 21 Mar 2024 23:17:51 -0400 Subject: [PATCH 9/9] Bayesian model and test --- modelforge/potential/bayesian_models.py | 39 ++++++++++++++++++------- modelforge/tests/test_bayesian_model.py | 2 +- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/modelforge/potential/bayesian_models.py b/modelforge/potential/bayesian_models.py index 2a6f285b..768f12ac 100644 --- a/modelforge/potential/bayesian_models.py +++ b/modelforge/potential/bayesian_models.py @@ -2,6 +2,19 @@ import pyro from pyro.nn.module import to_pyro_module_ +import functools + +def rsetattr(obj, attr, val): + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + +# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427 + +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + return functools.reduce(_getattr, [obj] + attr.split('.')) + def init_log_sigma(model, value): """Initializes the log_sigma parameters of a model @@ -14,17 +27,21 @@ def init_log_sigma(model, value): The value to initialize the log_sigma parameters to """ - log_sigma_params = { - name + "_log_sigma": pyro.nn.PyroParam( - torch.ones(param.shape) * value, + params = { + name: pyro.nn.PyroSample( + pyro.distributions.Normal( + torch.zeros(param.shape), + torch.ones(param.shape) * value, + ) ) for name, param in model.named_parameters() } - for name, param in log_sigma_params.items(): - setattr(model, name, param) + for name, param in model.named_parameters(): + rsetattr(model, name, params[name]) + -class BayesianAutoNormalPotential(torch.nn.Module): +class BayesianAutoNormalPotential(pyro.nn.PyroModule): """A Bayesian model with a normal prior and likelihood. Parameters @@ -39,19 +56,21 @@ class BayesianAutoNormalPotential(torch.nn.Module): provide the prior; if `y` is provided, provide the likelihood. """ def __init__( - self, + self, base_model, *args, **kwargs, ): super().__init__() + to_pyro_module_(base_model) + self.base_model = base_model log_sigma = kwargs.pop("log_sigma", 0.0) - init_log_sigma(self, log_sigma) + init_log_sigma(self.base_model, log_sigma) - def model(self, *args, **kwargs): + def forward(self, *args, **kwargs): """The model function. If no `y` argument is provided, provide the prior; if `y` is provided, provide the likelihood. """ y = kwargs.pop("y", None) - y_hat = self(*args, **kwargs) + y_hat = self.base_model(*args, **kwargs)["E_predict"] pyro.sample( "obs", pyro.distributions.Delta(y_hat), diff --git a/modelforge/tests/test_bayesian_model.py b/modelforge/tests/test_bayesian_model.py index f6cc21de..88c37669 100644 --- a/modelforge/tests/test_bayesian_model.py +++ b/modelforge/tests/test_bayesian_model.py @@ -10,7 +10,7 @@ def test_bayesian_model(input_data): schnet = SchNet() # make a Bayesian model from the SchNet - schnet = BayesianAutoNormalPotential(schnet, log_sigma=1e-2) + schnet = BayesianAutoNormalPotential(schnet, log_sigma=1e-2).forward guide = pyro.infer.autoguide.AutoDiagonalNormal(schnet) assert guide is not None