GPU vs CPU Bug #118

Open
MarkDanielArndt opened this issue Sep 25, 2023 · 0 comments

I found a bug in your code that only shows up when the model runs on the GPU instead of the CPU. It is reproduced in the Jupyter notebook below. If you want to work on the code, feel free to contact me. I have also included a fix; it is not perfect, but maybe we can work it out together.

Load your model (BayesianLinear):

import torch
from torch import nn
from torch.nn import functional as F
from blitz.modules.base_bayesian_module import BayesianModule
from blitz.modules.weight_sampler import TrainableRandomDistribution, PriorWeightDistribution


class BayesianLinear(BayesianModule):
    """
    Bayesian Linear layer, implements the linear layer proposed in Weight Uncertainty in Neural Networks
    (Bayes by Backprop paper).

    It is designed to be compatible with the torch nn.Module API, so it can even be chained in nn.Sequential models with layers from other libraries

    parameters:
        in_features: int -> incoming features for the layer
        out_features: int -> output features for the layer
        bias: bool -> whether the bias will exist (True) or set to zero (False)
        prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
        prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
        prior_pi: float -> pi on the scale mixture prior
        posterior_mu_init: float -> mean of the normal distribution used to initialize the posterior mu
        posterior_rho_init: float -> mean of the normal distribution used to initialize the posterior rho
        freeze: bool -> whether the model will start with frozen (deterministic) weights, or not
    
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_sigma_1 = 0.1,
                 prior_sigma_2 = 0.4,
                 prior_pi = 1,
                 posterior_mu_init = 0,
                 posterior_rho_init = -7.0,
                 freeze = False,
                 prior_dist = None):
        super().__init__()

        #our main parameters
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.freeze = freeze


        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        #parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # Variational weight parameters and sample
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).normal_(posterior_mu_init, 0.1))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).normal_(posterior_rho_init, 0.1))
        self.weight_sampler = TrainableRandomDistribution(self.weight_mu, self.weight_rho)

        # Variational bias parameters and sample
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).normal_(posterior_mu_init, 0.1))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).normal_(posterior_rho_init, 0.1))
        self.bias_sampler = TrainableRandomDistribution(self.bias_mu, self.bias_rho)

        # Priors (as BBP paper)
        self.weight_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
        self.bias_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        # Sample the weights and forward it
        
        #if the model is frozen, return frozen
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        if self.bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)

        else:
            b = torch.zeros((self.out_features), device=x.device)
            b_log_posterior = 0
            b_log_prior = 0

        # Get the complexity cost
        self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
        self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior

        return F.linear(x, w, b)

    def forward_frozen(self, x):
        """
        Computes the feedforward operation with the expected value for weight and biases
        """
        if self.bias:
            return F.linear(x, self.weight_mu, self.bias_mu)
        else:
            return F.linear(x, self.weight_mu, torch.zeros(self.out_features, device=x.device))

Set up the example from your library (once entirely on the CPU, once entirely on the GPU):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

#from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator

from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

X, y = fetch_california_housing(return_X_y=True)
X = StandardScaler().fit_transform(X)
y = StandardScaler().fit_transform(np.expand_dims(y, -1))

X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    test_size=.25,
                                                    random_state=42)

device_gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_cpu = "cpu"

X_train_cpu, y_train_cpu = torch.tensor(X_train).float().to(device_cpu), torch.tensor(y_train).float().to(device_cpu)
X_test_cpu, y_test_cpu = torch.tensor(X_test).float().to(device_cpu), torch.tensor(y_test).float().to(device_cpu)

X_train_gpu, y_train_gpu = torch.tensor(X_train).float().to(device_gpu), torch.tensor(y_train).float().to(device_gpu)
X_test_gpu, y_test_gpu = torch.tensor(X_test).float().to(device_gpu), torch.tensor(y_test).float().to(device_gpu)

@variational_estimator
class BayesianRegressor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 128)
        self.blinear1 = BayesianLinear(128, 512)
        self.blinear2 = BayesianLinear(512, output_dim)

    def forward(self, x):
        x_ = self.linear(x)
        x_ = self.blinear1(x_)
        x_ = F.relu(x_)
        return self.blinear2(x_)


def evaluate_regression(regressor,
                        X,
                        y,
                        samples = 100,
                        std_multiplier = 2):
    # Draw `samples` stochastic predictions; no autograd graph is needed for evaluation
    with torch.no_grad():
        preds = torch.stack([regressor(X) for _ in range(samples)])
    means = preds.mean(axis=0)
    stds = preds.std(axis=0)
    ci_upper = means + (std_multiplier * stds)
    ci_lower = means - (std_multiplier * stds)
    # Fraction of targets that fall inside the confidence interval
    ic_acc = ((ci_lower <= y) & (ci_upper >= y)).float().mean()
    return ic_acc, (ci_upper >= y).float().mean(), (ci_lower <= y).float().mean()

regressor_cpu = BayesianRegressor(8, 1).to(device_cpu)
regressor_gpu = BayesianRegressor(8, 1).to(device_gpu)
optimizer_cpu = optim.Adam(regressor_cpu.parameters(), lr=0.01)
optimizer_gpu = optim.Adam(regressor_gpu.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

ds_train_cpu = torch.utils.data.TensorDataset(X_train_cpu, y_train_cpu)
dataloader_train_cpu = torch.utils.data.DataLoader(ds_train_cpu, batch_size=16, shuffle=True)

ds_test_cpu = torch.utils.data.TensorDataset(X_test_cpu, y_test_cpu)
dataloader_test_cpu = torch.utils.data.DataLoader(ds_test_cpu, batch_size=16, shuffle=True)

ds_train_gpu = torch.utils.data.TensorDataset(X_train_gpu, y_train_gpu)
dataloader_train_gpu = torch.utils.data.DataLoader(ds_train_gpu, batch_size=16, shuffle=True)

ds_test_gpu = torch.utils.data.TensorDataset(X_test_gpu, y_test_gpu)
dataloader_test_gpu = torch.utils.data.DataLoader(ds_test_gpu, batch_size=16, shuffle=True)

Run the example (on both CPU and GPU):

iteration = 0
for epoch in range(1):
    for i, (datapoints, labels) in enumerate(dataloader_train_cpu):
        optimizer_cpu.zero_grad()

        loss = regressor_cpu.sample_elbo(inputs=datapoints.to(device_cpu),
                           labels=labels.to(device_cpu),
                           criterion=criterion,
                           sample_nbr=3,
                           complexity_cost_weight=10000)
        loss.backward()
        optimizer_cpu.step()

        iteration += 1
        if iteration%100==0:
            ic_acc, under_ci_upper, over_ci_lower = evaluate_regression(regressor_cpu,
                                                                        X_test_cpu.to(device_cpu),
                                                                        y_test_cpu.to(device_cpu),
                                                                        samples=25,
                                                                        std_multiplier=3)
            print("cpu:")
            print(regressor_cpu.blinear1.weight_mu[0,4])
            print(regressor_cpu.blinear1.weight_rho[0,4])
        if iteration == 300:
            break
            
iteration = 0
for epoch in range(1):
    for i, (datapoints, labels) in enumerate(dataloader_train_gpu):
        optimizer_gpu.zero_grad()

        loss = regressor_gpu.sample_elbo(inputs=datapoints.to(device_gpu),
                           labels=labels.to(device_gpu),
                           criterion=criterion,
                           sample_nbr=3,
                           complexity_cost_weight=10000)
        loss.backward()
        optimizer_gpu.step()

        iteration += 1
        if iteration%100==0:
            ic_acc, under_ci_upper, over_ci_lower = evaluate_regression(regressor_gpu,
                                                                        X_test_gpu.to(device_gpu),
                                                                        y_test_gpu.to(device_gpu),
                                                                        samples=25,
                                                                        std_multiplier=3)
            print("gpu:")
            print(regressor_gpu.blinear1.weight_mu[0,4])
            print(regressor_gpu.blinear1.weight_rho[0,4])
            #print("CI acc: {:.2f}, CI upper acc: {:.2f}, CI lower acc: {:.2f}".format(ic_acc, under_ci_upper, over_ci_lower))
            #print("Loss: {:.4f}".format(loss))
            print(iteration)
        if iteration == 300:
            break

cpu:
tensor(0.0009, grad_fn=)
tensor(-6.4924, grad_fn=)
cpu:
tensor(0.0003, grad_fn=)
tensor(-6.0803, grad_fn=)
cpu:
tensor(0.0009, grad_fn=)
tensor(-5.7967, grad_fn=)
gpu:
tensor(-0.1156, device='cuda:0', grad_fn=)
tensor(-6.9336, device='cuda:0', grad_fn=)
100
gpu:
tensor(-0.1156, device='cuda:0', grad_fn=)
tensor(-6.9336, device='cuda:0', grad_fn=)
200
gpu:
tensor(-0.1156, device='cuda:0', grad_fn=)
tensor(-6.9336, device='cuda:0', grad_fn=)
300

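There is a problem with the GPU: on the CPU, `blinear1.weight_mu[0,4]` and `blinear1.weight_rho[0,4]` change between iterations 100, 200 and 300, but on the GPU they stay at exactly -0.1156 and -6.9336. Somehow the GPU version does not train the way the CPU version does.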
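My best guess at the cause (an assumption, not confirmed by the maintainers): `BayesianLinear` registers `self.weight_mu` as a parameter and then hands it to `TrainableRandomDistribution`, which, as far as I can tell from the blitz source, wraps it in a second `nn.Parameter`. The two objects start out sharing memory, but only the sampler's copy is used in `forward`. `Module.to()` converts every parameter on its own, so after `.to('cuda')` the copies no longer share memory: the sampler keeps training while `blinear1.weight_mu`, which is what gets printed, stays at its initial value. `.to('cpu')` on a CPU model changes nothing, so there the sharing survives. The sketch below reproduces this in plain PyTorch; `Sampler`, `Layer` and `shares_storage` are hypothetical names, and `.double()` stands in for `.to('cuda')` so it runs without a GPU (it also converts each parameter separately).

```python
import torch
from torch import nn


class Sampler(nn.Module):
    """Hypothetical stand-in for TrainableRandomDistribution.

    Assumption: like the blitz class appears to do, it wraps the tensor it
    receives in a new nn.Parameter, which aliases the same storage.
    """

    def __init__(self, mu):
        super().__init__()
        self.mu = nn.Parameter(mu)


class Layer(nn.Module):
    """Hypothetical stand-in for BayesianLinear: mu is registered twice."""

    def __init__(self):
        super().__init__()
        self.weight_mu = nn.Parameter(torch.zeros(3))
        self.weight_sampler = Sampler(self.weight_mu)

    def forward(self, x):
        # Only the sampler's copy takes part in the computation
        return (x * self.weight_sampler.mu).sum()


def shares_storage(layer):
    return layer.weight_mu.data_ptr() == layer.weight_sampler.mu.data_ptr()


print(shares_storage(Layer()))             # True
print(shares_storage(Layer().to("cpu")))   # True: no-op on a CPU model

# .double() converts each parameter independently, as .to("cuda") does
layer = Layer().double()
print(shares_storage(layer))               # False

opt = torch.optim.SGD(layer.parameters(), lr=0.1)
layer(torch.ones(3, dtype=torch.float64)).backward()
opt.step()
print(layer.weight_sampler.mu.tolist())    # [-0.1, -0.1, -0.1]
print(layer.weight_mu.tolist())            # [0.0, 0.0, 0.0]: never receives a gradient
```

If this guess is right, the GPU model did train through the sampler's parameters; only the separately registered `weight_mu`/`weight_rho` (and therefore `forward_frozen`, which uses them) are stale.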

New, adapted BayesianLinear with an extra `device` argument:

import torch
from torch import nn
from torch.nn import functional as F
from blitz.modules.base_bayesian_module import BayesianModule
from blitz.modules.weight_sampler import TrainableRandomDistribution, PriorWeightDistribution


class BayesianLinear(BayesianModule):
    """
    Bayesian Linear layer, implements the linear layer proposed in Weight Uncertainty in Neural Networks
    (Bayes by Backprop paper).
    It is designed to be compatible with the torch nn.Module API, so it can even be chained in nn.Sequential models with layers from other libraries

    parameters:
        in_features: int -> incoming features for the layer
        out_features: int -> output features for the layer
        bias: bool -> whether the bias will exist (True) or set to zero (False)
        prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
        prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
        prior_pi: float -> pi on the scale mixture prior
        posterior_mu_init: float -> mean of the normal distribution used to initialize the posterior mu
        posterior_rho_init: float -> mean of the normal distribution used to initialize the posterior rho
        freeze: bool -> whether the model will start with frozen (deterministic) weights, or not

    """
    weight_mu: torch.Tensor
    weight_rho: torch.Tensor

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_sigma_1 = 0.1,
                 prior_sigma_2 = 0.4,
                 prior_pi = 1,
                 posterior_mu_init = 0,
                 posterior_rho_init = -7.0,
                 freeze = False,
                 prior_dist = None,
                 device = None,
                 dtype = None):
        
        super().__init__()

        #our main parameters
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.freeze = freeze


        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        #parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist


        # Variational weight parameters and sample, created directly on the requested device
        # so that a later .to(device) is a no-op and keeps them aliased with the sampler's copies
        weight_mu = torch.empty(out_features, in_features, device=device, dtype=dtype).normal_(posterior_mu_init, 0.1)
        self.weight_mu = nn.Parameter(weight_mu)

        weight_rho = torch.empty(out_features, in_features, device=device, dtype=dtype).normal_(posterior_rho_init, 0.1)
        self.weight_rho = nn.Parameter(weight_rho)

        self.weight_sampler = TrainableRandomDistribution(self.weight_mu, self.weight_rho)

        # Variational bias parameters and sample, also created on the requested device
        bias_mu = torch.empty(out_features, device=device, dtype=dtype).normal_(posterior_mu_init, 0.1)
        self.bias_mu = nn.Parameter(bias_mu)

        bias_rho = torch.empty(out_features, device=device, dtype=dtype).normal_(posterior_rho_init, 0.1)
        self.bias_rho = nn.Parameter(bias_rho)

        self.bias_sampler = TrainableRandomDistribution(self.bias_mu, self.bias_rho)

        # Priors (as BBP paper)
        self.weight_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
        self.bias_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist = self.prior_dist)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        # Sample the weights and forward it

        #if the model is frozen, return frozen
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        if self.bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)

        else:
            b = torch.zeros((self.out_features), device=x.device)
            b_log_posterior = 0
            b_log_prior = 0

        # Get the complexity cost
        self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
        self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior

        return F.linear(x, w, b)

    def forward_frozen(self, x):
        """
        Computes the feedforward operation with the expected value for weight and biases
        """
        if self.bias:
            return F.linear(x, self.weight_mu, self.bias_mu)
        else:
            return F.linear(x, self.weight_mu, torch.zeros(self.out_features, device=x.device))

Run the same example as above:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

#from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator

from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

X, y = fetch_california_housing(return_X_y=True)
X = StandardScaler().fit_transform(X)
y = StandardScaler().fit_transform(np.expand_dims(y, -1))

X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    test_size=.25,
                                                    random_state=42)

device_gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_cpu = "cpu"

X_train_cpu, y_train_cpu = torch.tensor(X_train).float().to(device_cpu), torch.tensor(y_train).float().to(device_cpu)
X_test_cpu, y_test_cpu = torch.tensor(X_test).float().to(device_cpu), torch.tensor(y_test).float().to(device_cpu)

X_train_gpu, y_train_gpu = torch.tensor(X_train).float().to(device_gpu), torch.tensor(y_train).float().to(device_gpu)
X_test_gpu, y_test_gpu = torch.tensor(X_test).float().to(device_gpu), torch.tensor(y_test).float().to(device_gpu)

@variational_estimator
class BayesianRegressor(nn.Module):
    def __init__(self, input_dim, output_dim, device):
        super().__init__()
        self.linear = nn.Linear(input_dim, 128)
        self.blinear1 = BayesianLinear(128, 512, device=device)
        self.blinear2 = BayesianLinear(512, output_dim, device=device)

    def forward(self, x):
        x_ = self.linear(x)
        x_ = self.blinear1(x_)
        x_ = F.relu(x_)
        return self.blinear2(x_)


def evaluate_regression(regressor,
                        X,
                        y,
                        samples = 100,
                        std_multiplier = 2):
    # Draw `samples` stochastic predictions; no autograd graph is needed for evaluation
    with torch.no_grad():
        preds = torch.stack([regressor(X) for _ in range(samples)])
    means = preds.mean(axis=0)
    stds = preds.std(axis=0)
    ci_upper = means + (std_multiplier * stds)
    ci_lower = means - (std_multiplier * stds)
    # Fraction of targets that fall inside the confidence interval
    ic_acc = ((ci_lower <= y) & (ci_upper >= y)).float().mean()
    return ic_acc, (ci_upper >= y).float().mean(), (ci_lower <= y).float().mean()

regressor_cpu = BayesianRegressor(8, 1, device=device_cpu).to(device_cpu)
regressor_gpu = BayesianRegressor(8, 1, device=device_gpu).to(device_gpu)
optimizer_cpu = optim.Adam(regressor_cpu.parameters(), lr=0.01)
optimizer_gpu = optim.Adam(regressor_gpu.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

ds_train_cpu = torch.utils.data.TensorDataset(X_train_cpu, y_train_cpu)
dataloader_train_cpu = torch.utils.data.DataLoader(ds_train_cpu, batch_size=16, shuffle=True)

ds_test_cpu = torch.utils.data.TensorDataset(X_test_cpu, y_test_cpu)
dataloader_test_cpu = torch.utils.data.DataLoader(ds_test_cpu, batch_size=16, shuffle=True)

ds_train_gpu = torch.utils.data.TensorDataset(X_train_gpu, y_train_gpu)
dataloader_train_gpu = torch.utils.data.DataLoader(ds_train_gpu, batch_size=16, shuffle=True)

ds_test_gpu = torch.utils.data.TensorDataset(X_test_gpu, y_test_gpu)
dataloader_test_gpu = torch.utils.data.DataLoader(ds_test_gpu, batch_size=16, shuffle=True)

Run the example (on both CPU and GPU):

iteration = 0
for epoch in range(1):
    for i, (datapoints, labels) in enumerate(dataloader_train_cpu):
        optimizer_cpu.zero_grad()

        loss = regressor_cpu.sample_elbo(inputs=datapoints.to(device_cpu),
                           labels=labels.to(device_cpu),
                           criterion=criterion,
                           sample_nbr=3,
                           complexity_cost_weight=10000)
        loss.backward()
        optimizer_cpu.step()

        iteration += 1
        if iteration%100==0:
            ic_acc, under_ci_upper, over_ci_lower = evaluate_regression(regressor_cpu,
                                                                        X_test_cpu.to(device_cpu),
                                                                        y_test_cpu.to(device_cpu),
                                                                        samples=25,
                                                                        std_multiplier=3)
            print("cpu:")
            print(regressor_cpu.blinear1.weight_mu[0,4])
            print(regressor_cpu.blinear1.weight_rho[0,4])
        if iteration == 300:
            break
            
iteration = 0
for epoch in range(1):
    for i, (datapoints, labels) in enumerate(dataloader_train_gpu):
        optimizer_gpu.zero_grad()

        loss = regressor_gpu.sample_elbo(inputs=datapoints.to(device_gpu),
                           labels=labels.to(device_gpu),
                           criterion=criterion,
                           sample_nbr=3,
                           complexity_cost_weight=10000)
        loss.backward()
        optimizer_gpu.step()

        iteration += 1
        if iteration%100==0:
            ic_acc, under_ci_upper, over_ci_lower = evaluate_regression(regressor_gpu,
                                                                        X_test_gpu.to(device_gpu),
                                                                        y_test_gpu.to(device_gpu),
                                                                        samples=25,
                                                                        std_multiplier=3)
            print("gpu:")
            print(regressor_gpu.blinear1.weight_mu[0,4])
            print(regressor_gpu.blinear1.weight_rho[0,4])
            #print("CI acc: {:.2f}, CI upper acc: {:.2f}, CI lower acc: {:.2f}".format(ic_acc, under_ci_upper, over_ci_lower))
            #print("Loss: {:.4f}".format(loss))
            print(iteration)
        if iteration == 300:
            break

cpu:
tensor(0.0004, grad_fn=)
tensor(-6.6686, grad_fn=)
cpu:
tensor(-0.0003, grad_fn=)
tensor(-6.4068, grad_fn=)
cpu:
tensor(0.0024, grad_fn=)
tensor(-5.9668, grad_fn=)
gpu:
tensor(-0.0032, device='cuda:0', grad_fn=)
tensor(-6.2961, device='cuda:0', grad_fn=)
100
gpu:
tensor(0.0048, device='cuda:0', grad_fn=)
tensor(-5.7684, device='cuda:0', grad_fn=)
200
gpu:
tensor(0.0052, device='cuda:0', grad_fn=)
tensor(-5.2635, device='cuda:0', grad_fn=)
300

With this adapted class it works now: on the GPU, the printed weights also change between iterations.

However, the constructor now needs an additional `device` argument; this can probably be avoided easily.
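One way this could be solved without the extra argument (a sketch, not tested against blitz itself): register each tensor only once, inside the sampler, and expose `weight_mu`/`weight_rho` as read-only properties that return the sampler's parameters. With a single parameter object per quantity, there is nothing for `.to(...)` to pull apart. `Sampler` and `Layer` are hypothetical, simplified stand-ins (no bias, no prior or KL terms); applying this to blitz would assume the sampler's attributes are named `mu` and `rho`.

```python
import torch
from torch import nn
from torch.nn import functional as F


class Sampler(nn.Module):
    """Hypothetical, simplified stand-in for TrainableRandomDistribution."""

    def __init__(self, mu, rho):
        super().__init__()
        # The only place where mu and rho are registered as parameters
        self.mu = nn.Parameter(mu)
        self.rho = nn.Parameter(rho)

    def sample(self):
        # Reparameterization trick: w = mu + softplus(rho) * eps
        sigma = torch.log1p(torch.exp(self.rho))
        return self.mu + sigma * torch.randn_like(self.mu)


class Layer(nn.Module):
    """Hypothetical Bayesian linear layer (no bias) that needs no device argument."""

    def __init__(self, in_features, out_features, mu_init=0.0, rho_init=-7.0):
        super().__init__()
        mu = torch.empty(out_features, in_features).normal_(mu_init, 0.1)
        rho = torch.empty(out_features, in_features).normal_(rho_init, 0.1)
        # Pass plain tensors; they are NOT also assigned as parameters on this module
        self.weight_sampler = Sampler(mu, rho)

    @property
    def weight_mu(self):
        return self.weight_sampler.mu

    @property
    def weight_rho(self):
        return self.weight_sampler.rho

    def forward(self, x):
        return F.linear(x, self.weight_sampler.sample())


torch.manual_seed(0)
layer = Layer(4, 2).double()  # .double() stands in for .to("cuda")
print(layer.weight_mu is layer.weight_sampler.mu)  # True

before = layer.weight_mu.detach().clone()
opt = torch.optim.Adam(layer.parameters(), lr=0.01)
layer(torch.randn(8, 4, dtype=torch.float64)).pow(2).mean().backward()
opt.step()
print(torch.equal(before, layer.weight_mu))  # False: weight_mu follows training
```

A side effect of this design is that the top-level `weight_mu`/`weight_rho` entries disappear from the `state_dict`, so checkpoints saved with the current layer would need their keys remapped.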
