LSTM_model.py
import torch
import torch.nn as nn


class LSTMmodel(nn.Module):
    """Three stacked single-layer LSTMs followed by a linear output layer."""

    def __init__(self, input_size, hidden_size_1, hidden_size_2, hidden_size_3, out_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size_1 = hidden_size_1
        self.hidden_size_2 = hidden_size_2
        self.hidden_size_3 = hidden_size_3
        # Each nn.LSTM expects input shaped (seq_len, batch, features) since batch_first is left at its default.
        self.lstm_1 = nn.LSTM(input_size, hidden_size_1, num_layers=1)
        self.lstm_2 = nn.LSTM(hidden_size_1, hidden_size_2, num_layers=1)
        self.lstm_3 = nn.LSTM(hidden_size_2, hidden_size_3, num_layers=1)
        self.linear = nn.Linear(hidden_size_3, out_size)
        # Zero-initialized (h, c) states for a batch size of 1; overwritten with the
        # final states of each forward pass.
        self.hidden_1 = (torch.zeros(1, 1, hidden_size_1), torch.zeros(1, 1, hidden_size_1))
        self.hidden_2 = (torch.zeros(1, 1, hidden_size_2), torch.zeros(1, 1, hidden_size_2))
        self.hidden_3 = (torch.zeros(1, 1, hidden_size_3), torch.zeros(1, 1, hidden_size_3))

    def forward(self, seq):
        # Pass the sequence through the three LSTM layers, keeping each layer's final (h, c) state.
        lstm_out_1, self.hidden_1 = self.lstm_1(seq)
        lstm_out_2, self.hidden_2 = self.lstm_2(lstm_out_1)
        lstm_out_3, self.hidden_3 = self.lstm_3(lstm_out_2)
        # Flatten the per-step outputs of the last LSTM and map them to the output size.
        pred = self.linear(lstm_out_3.view(len(seq), -1))
        return pred
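
A minimal usage sketch (the layer sizes and the dummy input below are illustrative assumptions, not taken from this repository): the model expects a sequence shaped (seq_len, batch=1, input_size) and returns one prediction per time step.

# Hypothetical sizes, for illustration only.
model = LSTMmodel(input_size=1, hidden_size_1=64, hidden_size_2=32, hidden_size_3=16, out_size=1)
seq = torch.randn(50, 1, 1)   # (seq_len=50, batch=1, input_size=1)
preds = model(seq)            # shape (50, 1): one prediction per time step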