Found a bug in your code: it behaves differently on CPU vs. GPU. The bug is shown in the Jupyter notebook below. If you want the code, feel free to contact me. I also provide a solution; it is not perfect, but maybe we can work this out together.
Load in your Model: Bayesian Linear
import torch
from torch import nn
from torch.nn import functional as F
from blitz.modules.base_bayesian_module import BayesianModule
from blitz.modules.weight_sampler import TrainableRandomDistribution, PriorWeightDistribution
class BayesianLinear(BayesianModule):
"""
Bayesian Linear layer, implements the linear layer proposed in Weight Uncertainty in Neural Networks
(Bayes by Backprop paper).
Its objective is to be interoperable with the torch nn.Module API, being able even to be chained in nn.Sequential models with other non-this-lib layers
parameters:
in_features: int -> incoming features for the layer
out_features: int -> output features for the layer
bias: bool -> whether the bias will exist (True) or be set to zero (False)
prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
prior_pi: float -> pi on the scaled mixture prior
posterior_mu_init: float -> posterior mean for the weight mu init
posterior_rho_init: float -> posterior mean for the weight rho init
freeze: bool -> whether the model will start with frozen (deterministic) weights, or not
"""
def __init__(self,
in_features,
out_features,
bias=True,
prior_sigma_1 = 0.1,
prior_sigma_2 = 0.4,
prior_pi = 1,
posterior_mu_init = 0,
posterior_rho_init = -7.0,
freeze = False,
prior_dist = None):
super().__init__()
#our main parameters
self.in_features = in_features
self.out_features = out_features
self.bias = bias
self.freeze = freeze
self.posterior_mu_init = posterior_mu_init
self.posterior_rho_init = posterior_rho_init
#parameters for the scale mixture prior
self.prior_sigma_1 = prior_sigma_1
self.prior_sigma_2 = prior_sigma_2
self.prior_pi = prior_pi
self.prior_dist = prior_dist
# Variational weight parameters and sample
self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).normal_(posterior_mu_init, 0.1))
self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).normal_(posterior_rho_init, 0.1))
self.weight_sampler = TrainableRandomDistribution(self.weight_mu, self.weight_rho)
# Variational bias parameters and sample
self.bias_mu = nn.Parameter(torch.Tensor(out_features).normal_(posterior_mu_init, 0.1))
self.bias_rho = nn.Parameter(torch.Tensor(out_features).normal_(posterior_rho_init, 0.1))
self.bias_sampler = TrainableRandomDistribution(self.bias_mu, self.bias_rho)
# Priors (as BBP paper)
self.weight_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
self.bias_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
self.log_prior = 0
self.log_variational_posterior = 0
def forward(self, x):
# Sample the weights and forward it
#if the model is frozen, return frozen
if self.freeze:
return self.forward_frozen(x)
w = self.weight_sampler.sample()
if self.bias:
b = self.bias_sampler.sample()
b_log_posterior = self.bias_sampler.log_posterior()
b_log_prior = self.bias_prior_dist.log_prior(b)
else:
b = torch.zeros((self.out_features), device=x.device)
b_log_posterior = 0
b_log_prior = 0
# Get the complexity cost
self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior
return F.linear(x, w, b)
def forward_frozen(self, x):
"""
Computes the feedforward operation with the expected value for weight and biases
"""
if self.bias:
return F.linear(x, self.weight_mu, self.bias_mu)
else:
return F.linear(x, self.weight_mu, torch.zeros(self.out_features))
Import the example from your library (once with everything on CPU, once with everything on GPU)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
#from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
X, y = fetch_california_housing(return_X_y=True)
X = StandardScaler().fit_transform(X)
y = StandardScaler().fit_transform(np.expand_dims(y, -1))
X_train, X_test, y_train, y_test = train_test_split(X,
y,
test_size=.25,
random_state=42)
device_gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_cpu = "cpu"
X_train_cpu, y_train_cpu = torch.tensor(X_train).float().to(device_cpu), torch.tensor(y_train).float().to(device_cpu)
X_test_cpu, y_test_cpu = torch.tensor(X_test).float().to(device_cpu), torch.tensor(y_test).float().to(device_cpu)
X_train_gpu, y_train_gpu = torch.tensor(X_train).float().to(device_gpu), torch.tensor(y_train).float().to(device_gpu)
X_test_gpu, y_test_gpu = torch.tensor(X_test).float().to(device_gpu), torch.tensor(y_test).float().to(device_gpu)
@variational_estimator
class BayesianRegressor(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.linear = nn.Linear(input_dim, 128)
self.blinear1 = BayesianLinear(128, 512)#.to(device) #, device=device
self.blinear2 = BayesianLinear(512, output_dim)
def forward(self, x):
x_ = self.linear(x)
x_ = self.blinear1(x_)
x_ = F.relu(x_)
return self.blinear2(x_)
def evaluate_regression(regressor,
X,
y,
samples = 100,
std_multiplier = 2):
preds = [regressor(X) for i in range(samples)]
preds = torch.stack(preds)
means = preds.mean(axis=0)
stds = preds.std(axis=0)
ci_upper = means + (std_multiplier * stds)
ci_lower = means - (std_multiplier * stds)
ic_acc = (ci_lower <= y) * (ci_upper >= y)
ic_acc = ic_acc.float().mean()
return ic_acc, (ci_upper >= y).float().mean(), (ci_lower <= y).float().mean()
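The cell that actually runs the example is not reproduced in this issue, only its outputs (shown further below). For context, a minimal sketch of such a run could look like the following; this assumes blitz's sample_elbo method added by @variational_estimator and an Adam optimizer, and the exact losses/prints of the original cell may differ:

# Minimal sketch of a run cell (not the original one): train on a given device
# and print intermediate results every 100 iterations.
def run_example(device, X_train, y_train, X_test, y_test, iterations=300, lr=0.01):
    regressor = BayesianRegressor(X_train.shape[1], 1).to(device)
    optimizer = optim.Adam(regressor.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for it in range(1, iterations + 1):
        optimizer.zero_grad()
        # sample_elbo averages the criterion loss plus the KL complexity cost
        # over sample_nbr Monte Carlo samples of the weights
        loss = regressor.sample_elbo(inputs=X_train,
                                     labels=y_train,
                                     criterion=criterion,
                                     sample_nbr=3)
        loss.backward()
        optimizer.step()

        if it % 100 == 0:
            ic_acc, under_ci_upper, over_ci_lower = evaluate_regression(regressor, X_test, y_test, samples=25)
            print(f"{device}:")
            print(loss)
            print(ic_acc)
            print(it)

# run_example(device_cpu, X_train_cpu, y_train_cpu, X_test_cpu, y_test_cpu)
# run_example(device_gpu, X_train_gpu, y_train_gpu, X_test_gpu, y_test_gpu)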
There is a problem with the GPU -> unlike on the CPU, the model somehow does not train (see the outputs further below).
New adapted BayesianLinear:
import torch
from torch import nn
from torch.nn import functional as F
from blitz.modules.base_bayesian_module import BayesianModule
from blitz.modules.weight_sampler import TrainableRandomDistribution, PriorWeightDistribution
from torch.nn.parameter import Parameter, UninitializedParameter
from torch.nn import init
class BayesianLinear(BayesianModule):
"""
Bayesian Linear layer, implements the linear layer proposed in Weight Uncertainty in Neural Networks
(Bayes by Backprop paper).
Its objective is to be interoperable with the torch nn.Module API, being able even to be chained in nn.Sequential models with other non-this-lib layers
parameters:
in_features: int -> incoming features for the layer
out_features: int -> output features for the layer
bias: bool -> whether the bias will exist (True) or be set to zero (False)
prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
prior_pi: float -> pi on the scaled mixture prior
posterior_mu_init: float -> posterior mean for the weight mu init
posterior_rho_init: float -> posterior mean for the weight rho init
freeze: bool -> whether the model will start with frozen (deterministic) weights, or not
"""
weight_mu: torch.Tensor
weight_rho: torch.Tensor
def __init__(self,
in_features,
out_features,
bias=True,
prior_sigma_1 = 0.1,
prior_sigma_2 = 0.4,
prior_pi = 1,
posterior_mu_init = 0,
posterior_rho_init = -7.0,
freeze = False,
prior_dist = None,
device = None,
dtype = None):
super().__init__()
#our main parameters
self.in_features = in_features
self.out_features = out_features
self.bias = bias
self.freeze = freeze
self.posterior_mu_init = posterior_mu_init
self.posterior_rho_init = posterior_rho_init
#parameters for the scale mixture prior
self.prior_sigma_1 = prior_sigma_1
self.prior_sigma_2 = prior_sigma_2
self.prior_pi = prior_pi
self.prior_dist = prior_dist
# Variational weight parameters and sample
weight_mu = torch.Tensor(out_features, in_features).normal_(posterior_mu_init, 0.1)
weight_mu = weight_mu.to(device=device)
self.weight_mu = nn.Parameter(weight_mu)
weight_rho = torch.Tensor(out_features, in_features).normal_(posterior_rho_init, 0.1)
weight_rho = weight_rho.to(device=device)
self.weight_rho = nn.Parameter(weight_rho)
self.weight_sampler = TrainableRandomDistribution(self.weight_mu, self.weight_rho)
# Variational bias parameters and sample
self.bias_mu = nn.Parameter(torch.Tensor(out_features).normal_(posterior_mu_init, 0.1))
self.bias_rho = nn.Parameter(torch.Tensor(out_features).normal_(posterior_rho_init, 0.1))
self.bias_sampler = TrainableRandomDistribution(self.bias_mu, self.bias_rho)
# Priors (as BBP paper)
self.weight_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
self.bias_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
self.log_prior = 0
self.log_variational_posterior = 0
def forward(self, x):
# Sample the weights and forward it
#if the model is frozen, return frozen
if self.freeze:
return self.forward_frozen(x)
w = self.weight_sampler.sample()
if self.bias:
b = self.bias_sampler.sample()
b_log_posterior = self.bias_sampler.log_posterior()
b_log_prior = self.bias_prior_dist.log_prior(b)
else:
b = torch.zeros((self.out_features), device=x.device)
b_log_posterior = 0
b_log_prior = 0
# Get the complexity cost
self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior
return F.linear(x, w, b)
def forward_frozen(self, x):
"""
Computes the feedforward operation with the expected value for weight and biases
"""
if self.bias:
return F.linear(x, self.weight_mu, self.bias_mu)
else:
return F.linear(x, self.weight_mu, torch.zeros(self.out_features))
Do the same example as above:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
#from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
X, y = fetch_california_housing(return_X_y=True)
X = StandardScaler().fit_transform(X)
y = StandardScaler().fit_transform(np.expand_dims(y, -1))
X_train, X_test, y_train, y_test = train_test_split(X,
y,
test_size=.25,
random_state=42)
device_gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_cpu = "cpu"
X_train_cpu, y_train_cpu = torch.tensor(X_train).float().to(device_cpu), torch.tensor(y_train).float().to(device_cpu)
X_test_cpu, y_test_cpu = torch.tensor(X_test).float().to(device_cpu), torch.tensor(y_test).float().to(device_cpu)
X_train_gpu, y_train_gpu = torch.tensor(X_train).float().to(device_gpu), torch.tensor(y_train).float().to(device_gpu)
X_test_gpu, y_test_gpu = torch.tensor(X_test).float().to(device_gpu), torch.tensor(y_test).float().to(device_gpu)
@variational_estimator
class BayesianRegressor(nn.Module):
def __init__(self, input_dim, output_dim, device):
super().__init__()
self.linear = nn.Linear(input_dim, 128)
self.blinear1 = BayesianLinear(128, 512, device=device) #, device=device
self.blinear2 = BayesianLinear(512, output_dim, device=device)
def forward(self, x):
x_ = self.linear(x)
x_ = self.blinear1(x_)
x_ = F.relu(x_)
return self.blinear2(x_)
def evaluate_regression(regressor,
X,
y,
samples = 100,
std_multiplier = 2):
preds = [regressor(X) for i in range(samples)]
preds = torch.stack(preds)
means = preds.mean(axis=0)
stds = preds.std(axis=0)
ci_upper = means + (std_multiplier * stds)
ci_lower = means - (std_multiplier * stds)
ic_acc = (ci_lower <= y) * (ci_upper >= y)
ic_acc = ic_acc.float().mean()
return ic_acc, (ci_upper >= y).float().mean(), (ci_lower <= y).float().mean()
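With the adapted class, a run analogous to the (hypothetical) run_example sketch above would presumably only change in how the regressor is constructed, since the device now has to be passed to the constructor, e.g.:

# Hypothetical helper, mirroring the sketch above: build the adapted regressor
# directly on the target device; .to(device) still moves the plain nn.Linear layer.
def make_regressor(device, input_dim=8, output_dim=1):
    return BayesianRegressor(input_dim, output_dim, device).to(device)

# regressor_cpu = make_regressor(device_cpu)
# regressor_gpu = make_regressor(device_gpu)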
Run the example (both on cpu and gpu). Output with the original BayesianLinear:
cpu:
tensor(0.0009, grad_fn=)
tensor(-6.4924, grad_fn=)
cpu:
tensor(0.0003, grad_fn=)
tensor(-6.0803, grad_fn=)
cpu:
tensor(0.0009, grad_fn=)
tensor(-5.7967, grad_fn=)
gpu:
tensor(-0.1156, device='cuda:0', grad_fn=)
tensor(-6.9336, device='cuda:0', grad_fn=)
100
gpu:
tensor(-0.1156, device='cuda:0', grad_fn=)
tensor(-6.9336, device='cuda:0', grad_fn=)
200
gpu:
tensor(-0.1156, device='cuda:0', grad_fn=)
tensor(-6.9336, device='cuda:0', grad_fn=)
300
Output with the new adapted BayesianLinear (the same example, again run on both cpu and gpu):
cpu:
tensor(0.0004, grad_fn=)
tensor(-6.6686, grad_fn=)
cpu:
tensor(-0.0003, grad_fn=)
tensor(-6.4068, grad_fn=)
cpu:
tensor(0.0024, grad_fn=)
tensor(-5.9668, grad_fn=)
gpu:
tensor(-0.0032, device='cuda:0', grad_fn=)
tensor(-6.2961, device='cuda:0', grad_fn=)
100
gpu:
tensor(0.0048, device='cuda:0', grad_fn=)
tensor(-5.7684, device='cuda:0', grad_fn=)
200
gpu:
tensor(0.0052, device='cuda:0', grad_fn=)
tensor(-5.2635, device='cuda:0', grad_fn=)
300
With this adapted class it works now.
But right now you have to pass an extra device argument to the constructor, which can probably be solved easily (one possible approach is sketched below).
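One possible way to avoid the extra required argument (only an illustration, not a change the library has made): since the adapted BayesianLinear already defaults device to None, the regressor can do the same, so existing CPU code keeps working unchanged while GPU users opt in.

# Illustration only: make device optional so the original call signature still works.
@variational_estimator
class BayesianRegressor(nn.Module):
    def __init__(self, input_dim, output_dim, device=None):
        super().__init__()
        self.linear = nn.Linear(input_dim, 128)
        # device=None keeps the old behaviour (parameters created on the default device)
        self.blinear1 = BayesianLinear(128, 512, device=device)
        self.blinear2 = BayesianLinear(512, output_dim, device=device)

    def forward(self, x):
        x_ = self.linear(x)
        x_ = self.blinear1(x_)
        x_ = F.relu(x_)
        return self.blinear2(x_)

# BayesianRegressor(8, 1)                                    # unchanged CPU call
# BayesianRegressor(8, 1, device=device_gpu).to(device_gpu)  # opt-in GPU call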