
High loss when using bayesian lstm instead of standard lstm #100

Open

amroghoneim opened this issue Dec 15, 2021 · 6 comments


amroghoneim commented Dec 15, 2021

I am trying to implement a model using the bayesian lstm layer, given that I already have a model that relies on a standard LSTM and gets good results on a classification task.
When I use the bayesian layer, the loss becomes very high and the accuracy barely converges. I tried changing the model's hyperparameters (especially the prior variables and posterior_rho), but that didn't help much. I also set sharpen=True for loss sharpening, but nothing changed.

The model:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
##### Bayesian version #####
from layers.lstm_bayesian_layer import BayesianLSTM
from blitz.utils import variational_estimator
from layers.linear_bayesian_layer import BayesianLinear

from layers.attention import Attention, NoQueryAttention
from layers.squeeze_embedding import SqueezeEmbedding

@variational_estimator
class LSTM_BAYES_RNN(nn.Module):
    def __init__(self, embedding_matrix, opt):
        super(LSTM_BAYES_RNN, self).__init__()
        # Bayesian LSTM over the concatenated [aspect ; word] embeddings
        self.lstm = BayesianLSTM(opt.embed_dim*2, opt.hidden_dim, bias=True, freeze=False,
                prior_sigma_1=5,
                prior_sigma_2=5,
                posterior_rho_init=1,
                sharpen=True)
                # prior_pi=1,
                # posterior_mu_init=0,
                # posterior_rho_init=-6.0,
        self.opt = opt
        # pretrained word embeddings (frozen by default in from_pretrained)
        self.embed = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))
        self.squeeze_embedding = SqueezeEmbedding()
        # self.dense = BayesianLinear(opt.hidden_dim, opt.polarities_dim, bias=True, freeze=False,
        #                             prior_sigma_1=10, prior_sigma_2=10, posterior_rho_init=5)
        self.attention = NoQueryAttention(opt.hidden_dim+opt.embed_dim, score_function='bi_linear')

        self.dense = nn.Linear(opt.hidden_dim, opt.polarities_dim)

    def forward(self, inputs):
        text_indices, aspect_indices = inputs[0], inputs[1]
        x_len = torch.sum(text_indices != 0, dim=-1)    # true lengths (non-padding tokens)
        x_len_max = torch.max(x_len)
        aspect_len = torch.sum(aspect_indices != 0, dim=-1).float()

        x = self.embed(text_indices)
        x = self.squeeze_embedding(x, x_len)            # trim padding to the batch max length
        aspect = self.embed(aspect_indices)
        aspect_pool = torch.div(torch.sum(aspect, dim=1), aspect_len.unsqueeze(1))  # mean-pool aspect embeddings
        aspect = aspect_pool.unsqueeze(1).expand(-1, x_len_max, -1)                 # broadcast over time steps
        x = torch.cat((aspect, x), dim=-1)              # [aspect ; word], size embed_dim*2

        h, (_, _) = self.lstm(x)
        ha = torch.cat((h, aspect), dim=-1)
        _, score = self.attention(ha)                   # attention weights over time steps
        output = torch.squeeze(torch.bmm(score, h), dim=1)  # attention-weighted sum of hidden states
        out = self.dense(output)                        # project to polarity logits
        return out

In the training loop I have:

# Bayesian loss calculation
pi_weight = minibatch_weight(batch_idx=i_batch, num_batches=self.opt.batch_size)

loss = self.model.sample_elbo(
        inputs=inputs,
        labels=targets,
        criterion=nn.CrossEntropyLoss(),
        sample_nbr=10,
        # complexity_cost_weight=1/len(self.trainset))
        complexity_cost_weight=pi_weight)

loss.backward()
optimizer.step()

# take 3 stochastic outputs per example and average them
outputs = torch.stack([self.model(inputs) for i in range(3)])
preds = torch.mean(outputs, axis=0)

What's the problem here?
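(For reference, the blitz README's regression example scales the complexity cost by the reciprocal of the dataset size, which keeps the KL term commensurate with the data-fit term. A minimal sketch of that convention, reusing the names from the snippet above, including the commented-out 1/len(self.trainset) weighting; whether this weighting suits this task is an assumption, not the library's required usage:)

# Sketch: KL weight scaled down by dataset size (complexity_cost_weight = 1/N),
# as in the blitz README's regression example. `inputs`, `targets`, and
# `self.trainset` reuse the names from the snippet above.
loss = self.model.sample_elbo(
        inputs=inputs,
        labels=targets,
        criterion=nn.CrossEntropyLoss(),
        sample_nbr=3,
        complexity_cost_weight=1/len(self.trainset))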

@amroghoneim (Author)

This is how the loss looks inside the sample_elbo method, for multiple samples over many epochs. Note that the printed values accumulate across the 10 Monte Carlo samples: each sample adds roughly 446 to the running total, of which the KL term contributes about 445 and the cross-entropy term only about 1.0-1.4.

PERFORMANCE LOSS: 1.3859792947769165
PERFORMANCE + KL LOSS: 446.3642578125
PERFORMANCE LOSS: 447.8734436035156
PERFORMANCE + KL LOSS: 892.888916015625
PERFORMANCE LOSS: 894.4247436523438
PERFORMANCE + KL LOSS: 1340.0321044921875
PERFORMANCE LOSS: 1341.6724853515625
PERFORMANCE + KL LOSS: 1786.862548828125
PERFORMANCE LOSS: 1788.3409423828125
PERFORMANCE + KL LOSS: 2233.776611328125
PERFORMANCE LOSS: 2235.102783203125
PERFORMANCE + KL LOSS: 2680.375244140625
PERFORMANCE LOSS: 2681.85400390625
PERFORMANCE + KL LOSS: 3127.18310546875
PERFORMANCE LOSS: 3128.93798828125
PERFORMANCE + KL LOSS: 3574.3134765625
PERFORMANCE LOSS: 3576.058349609375
PERFORMANCE + KL LOSS: 4021.171630859375
PERFORMANCE LOSS: 4023.035888671875
PERFORMANCE + KL LOSS: 4468.0283203125

[... the same 10-sample pattern repeats for every subsequent sample_elbo call: each call starts from a cross-entropy loss of about 1.0-1.4 and accumulates roughly 446 per sample ...]

loss: 446.3445, acc: 0.4250

@Philippe-Drolet

Having the same issue here...

@swjtufjs

Hello, have you solved it in the end? I have the same problem.

@0wenwu

0wenwu commented Nov 30, 2022

I have the same problem. Also, how do you output the uncertainty of the prediction result?

@piEsposito (Owner)

Hey, I'm sorry for the delay. Will try to take a look at it this week.

Maybe reducing the KL Divergence weight on the loss could help.

@0wenwu to output the uncertainty, you do multiple forward passes and check the variance; you can assume it is normally distributed.
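
A minimal sketch of that uncertainty estimate (assuming a trained variational model and a batch of inputs as in the snippets above; the sample count here is arbitrary):

import torch

# Each stochastic forward pass samples fresh weights, so the spread
# across passes approximates the predictive uncertainty.
n_samples = 50
with torch.no_grad():
    preds = torch.stack([model(inputs) for _ in range(n_samples)])  # (n_samples, batch, classes)
mean_pred = preds.mean(dim=0)  # point prediction
std_pred = preds.std(dim=0)    # per-output uncertainty, treated as Gaussian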

@0wenwu

0wenwu commented Jan 15, 2023

> Hey, I'm sorry for the delay. Will try to take a look at it this week.
>
> Maybe reducing the KL Divergence weight on the loss could help.
>
> @0wenwu to output the uncertainty, you do multiple forward passes and check the variance; you can assume it is normally distributed.

Hey @piEsposito,
I have solved the problem, thank you for your excellent work.
What we need to do is read the stocks-blstm.ipynb notebook carefully.
What do you think about the effect of sample_nbr on the prediction loss?
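
(On sample_nbr: inside sample_elbo it sets how many Monte Carlo weight samples are averaged to estimate the ELBO, so a larger value gives a smoother, lower-variance loss at proportionally higher compute cost. A hypothetical comparison, reusing the names from the question's training snippet:)

# Hypothetical: compare ELBO estimates for different Monte Carlo sample counts.
for n in (1, 3, 10):
    loss = self.model.sample_elbo(
            inputs=inputs,
            labels=targets,
            criterion=nn.CrossEntropyLoss(),
            sample_nbr=n,
            complexity_cost_weight=1/len(self.trainset))
    print(f"sample_nbr={n}: elbo={loss.item():.4f}")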
