You forgot the tanh function in the last computation of def bottom_data_is() #47

Open
mikechen66 opened this issue Dec 27, 2019 · 4 comments

Comments

@mikechen66

mikechen66 commented Dec 27, 2019

Issue: `lstm.py`, line 98.

There is a problem with the line `self.state.h = self.state.s * self.state.o`: the tanh function is missing. The formula is $h_t = o_t \odot \tanh(s_t)$, so the line should read:

```python
self.state.h = np.tanh(self.state.s) * self.state.o
```
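For reference, here are the standard forward-step equations written with the variable names used in `bottom_data_is()` ($x_c$ is the concatenation of $x_t$ and $h_{t-1}$, $\sigma$ is the sigmoid, and $\odot$ is element-wise multiplication). Only the last line differs from the current code:

```latex
\begin{aligned}
x_c &= [\,x_t;\ h_{t-1}\,] \\
g_t &= \tanh(W_g x_c + b_g) \\
i_t &= \sigma(W_i x_c + b_i) \\
f_t &= \sigma(W_f x_c + b_f) \\
o_t &= \sigma(W_o x_c + b_o) \\
s_t &= g_t \odot i_t + s_{t-1} \odot f_t \\
h_t &= o_t \odot \tanh(s_t)
\end{aligned}
```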

Here is the relevant excerpt of the code:

```python
def bottom_data_is(self, x, s_prev = None, h_prev = None):
    # if this is the first lstm node in the network
    if s_prev is None: s_prev = np.zeros_like(self.state.s)
    if h_prev is None: h_prev = np.zeros_like(self.state.h)
    # save data for use in backprop
    self.s_prev = s_prev
    self.h_prev = h_prev

    # concatenate x(t) and h(t-1)
    xc = np.hstack((x,  h_prev))
    self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg)
    self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi)
    self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf)
    self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo)
    self.state.s = self.state.g * self.state.i + s_prev * self.state.f
    # tanh missing here: expected self.state.o * np.tanh(self.state.s)
    self.state.h = self.state.s * self.state.o
```
@try1995

try1995 commented Dec 27, 2020

```python
self.state.h = self.state.o * np.tanh(self.state.s)
```

@cs-heibao

@mikechen66
There is also a problem in backpropagation: the derivative of the tanh function is ignored.
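A sketch of the chain rule for the tanh variant, writing $\delta h_t$ for `top_diff_h` and $\delta s_t^{\text{carry}}$ for `top_diff_s` (the cell-state gradient carried back from step $t+1$). Since $\frac{d}{ds}\tanh(s) = 1 - \tanh^2(s)$, the factor has to be evaluated at $\tanh(s_t)$, not at $s_t$:

```latex
\begin{aligned}
h_t &= o_t \odot \tanh(s_t) \\
\frac{\partial L}{\partial s_t} &= \delta h_t \odot o_t \odot \bigl(1 - \tanh^2(s_t)\bigr) + \delta s_t^{\text{carry}} \\
\frac{\partial L}{\partial o_t} &= \delta h_t \odot \tanh(s_t)
\end{aligned}
```

The gate gradients `di`, `dg` and `df` keep their current form, because $s_t = g_t \odot i_t + s_{t-1} \odot f_t$ does not involve the tanh; they only pick up the corrected `ds`.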

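```python
def top_diff_is(self, top_diff_h, top_diff_s):
    # notice that top_diff_s is carried along the constant error carousel
    ds = self.state.o * top_diff_h + top_diff_s
    do = self.state.s * top_diff_h
    di = self.state.g * ds
    dg = self.state.i * ds
    df = self.s_prev * ds
```

With the tanh in the forward pass, the first two lines should become:

```python
# d/ds tanh(s) = 1 - tanh(s)**2 (use **, since ^ is bitwise XOR in Python)
ds = self.state.o * (1 - np.tanh(self.state.s) ** 2) * top_diff_h + top_diff_s
do = np.tanh(self.state.s) * top_diff_h
```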
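To sanity-check the tanh-aware gradients, here is a minimal, self-contained finite-difference check with NumPy. It isolates the output step $h = o \odot \tanh(s)$ with random values; the helper names (`output_h`, `analytic_grads`, `loss`) are hypothetical and not from `lstm.py`. It also evaluates the `1 - s^2` variant (written with `**`) to show that it does not match the numerical gradient.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def output_h(s, o):
    # Output step of a standard LSTM: h = o * tanh(s).
    return o * np.tanh(s)

def analytic_grads(s, o, top_diff_h, top_diff_s):
    # tanh-aware gradients: d tanh(s)/ds = 1 - tanh(s)**2.
    ds = o * (1 - np.tanh(s) ** 2) * top_diff_h + top_diff_s
    do = np.tanh(s) * top_diff_h
    return ds, do

rng = np.random.default_rng(0)
n = 5
s = rng.normal(size=n)                 # random cell state
o = sigmoid(rng.normal(size=n))        # output gate values in (0, 1)
top_diff_h = rng.normal(size=n)        # stand-in for dL/dh from the layer above
top_diff_s = np.zeros(n)               # no carried cell-state gradient in this check

def loss(s, o):
    # Linear loss chosen so that dL/dh equals top_diff_h exactly.
    return np.sum(top_diff_h * output_h(s, o))

ds, do = analytic_grads(s, o, top_diff_h, top_diff_s)
wrong_ds = o * (1 - s ** 2) * top_diff_h + top_diff_s   # the "1 - s^2" variant

# Central finite differences, one coordinate at a time.
eps = 1e-6
num_ds = np.zeros(n)
num_do = np.zeros(n)
for k in range(n):
    e = np.zeros(n)
    e[k] = eps
    num_ds[k] = (loss(s + e, o) - loss(s - e, o)) / (2 * eps)
    num_do[k] = (loss(s, o + e) - loss(s, o - e)) / (2 * eps)

print("max |ds - numeric|:", np.max(np.abs(ds - num_ds)))
print("max |do - numeric|:", np.max(np.abs(do - num_do)))
print("max |wrong_ds - numeric|:", np.max(np.abs(wrong_ds - num_ds)))
```

The first two differences should be tiny (rounding error only), while the `1 - s**2` version is visibly off wherever `s` is not close to zero.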

@nicodjimenez
Owner

I think you're right that many, perhaps most, implementations apply the tanh, but that's not how I defined the forward pass in the blog article:

https://nicodjimenez.github.io/2014/08/08/lstm.html

If you want to make a PR to add that as an option, that's fine with me.
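One possible shape for such an option, as a minimal NumPy sketch: the `use_tanh` flag and the `output_forward` / `output_backward` helpers are hypothetical names, not part of `lstm.py`. The first would correspond to the last line of `bottom_data_is()` and the second to the `ds` / `do` lines of `top_diff_is()`.

```python
import numpy as np

def output_forward(s, o, use_tanh=False):
    """Compute h from the cell state s and the output gate o.

    use_tanh=False: h = o * s        (forward pass as defined in the blog article)
    use_tanh=True:  h = o * tanh(s)  (common LSTM formulation)
    """
    return o * (np.tanh(s) if use_tanh else s)

def output_backward(s, o, top_diff_h, top_diff_s, use_tanh=False):
    """Return (ds, do) for either variant; top_diff_s is the carried cell-state gradient."""
    if use_tanh:
        t = np.tanh(s)
        ds = o * (1 - t ** 2) * top_diff_h + top_diff_s
        do = t * top_diff_h
    else:
        # Same as the existing top_diff_is() lines for h = s * o.
        ds = o * top_diff_h + top_diff_s
        do = s * top_diff_h
    return ds, do

# Example: without the tanh, |h| can exceed 1; with it, |h| < 1 because o is in (0, 1).
s = np.array([2.0, -3.0])
o = np.array([0.5, 0.5])
h_plain = output_forward(s, o, use_tanh=False)
h_tanh = output_forward(s, o, use_tanh=True)
ds_plain, do_plain = output_backward(s, o, np.ones(2), np.zeros(2), use_tanh=False)
```

The flag defaults to `False` here so that the existing behavior, which matches the blog article, stays unchanged unless a caller opts in.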

@bot66

bot66 commented Dec 31, 2021

yes


5 participants