Commit 302f35a

initial commit

0 parents  commit 302f35a

File tree

7 files changed: +459 -0 lines changed

.gitignore

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
*.pkl

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation.
docs/_build/
docs/_build_html/
.buildinfo

# PyBuilder:
target/

# IPython Notebook:
*.ipynb_checkpoints

# pyenv:
*.python-version

# VIM:
*.swp

# celery beat schedule file:
celerybeat-schedule

# dotenv:
.env

# virtualenv:
venv/
ENV/

# Rope project settings:
.ropeproject

# PyCharm project settings:
.idea/*

README.md

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# Adaptive LSTM (aLSTM)

[PyTorch](https://pytorch.org/) implementation of the adaptive LSTM ([https://arxiv.org/abs/1805.08574](https://arxiv.org/abs/1805.08574)).

aLSTM is an extension of the standard LSTM that implements adaptive parameterization.
Adaptive parameterization increases model flexibility for a given parameter budget,
yielding more expressive and statistically efficient models. The aLSTM typically converges
faster than the LSTM and reaches better generalization performance. It is also very stable;
there is no need for gradient clipping, even on sequences of thousands of time steps.

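As a rough, schematic illustration of the idea (a minimal sketch only: the names `policy`, `project`, `cell`, and `adaptive_step` are illustrative and not part of this package, and the actual cell in `alstm/alstm.py` also rescales the gate pre-activations and bias), a small "policy" LSTM produces a vector at each step that elementwise rescales what the main cell consumes:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

input_size, hidden_size, adapt_size, batch_size = 10, 10, 3, 5

policy = nn.LSTMCell(input_size + hidden_size, adapt_size)  # small adaptation RNN
project = nn.Linear(adapt_size, input_size)                 # policy state -> elementwise scale
cell = nn.LSTMCell(input_size, hidden_size)                 # the "main" recurrent cell

def adaptive_step(x, hidden, policy_hidden):
    """One time step: the policy state rescales the main cell's input."""
    hx, cx = hidden
    ahx, acx = policy(torch.cat([x, hx], 1), policy_hidden)
    scale = project(ahx)                 # input-dependent scaling vector
    hx, cx = cell(x * scale, (hx, cx))   # main cell sees an adaptively rescaled input
    return (hx, cx), (ahx, acx)

x = Variable(torch.rand(batch_size, input_size))
hidden = (Variable(torch.zeros(batch_size, hidden_size)),
          Variable(torch.zeros(batch_size, hidden_size)))
policy_hidden = (Variable(torch.zeros(batch_size, adapt_size)),
                 Variable(torch.zeros(batch_size, adapt_size)))
hidden, policy_hidden = adaptive_step(x, hidden, policy_hidden)
```
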
If you use this code or our results in your research, please cite

```
@article{Flennerhag:2018alstm,
  title = {{Breaking the Activation Function Bottleneck through Adaptive Parameterization}},
  author = {Flennerhag, Sebastian and Yin, Hujun and Keane, John and Elliot, Mark},
  journal = {{arXiv preprint, arXiv:1805.08574}},
  year = {2018}
}
```

## Requirements

This codebase should run on any [PyTorch](https://pytorch.org/) version, but has only been tested on v0.2–v0.4. To install:

```bash
git clone https://github.com/flennerhag/alstm; cd alstm
python setup.py install
```

## Usage

This implementation follows the LSTM implementation in the official (and constantly changing)
[PyTorch repo](https://github.com/pytorch/pytorch). We expose an ``alstm_cell`` function and its ``aLSTMCell``
module wrapper, which operate on a single time step. The ``aLSTM`` class is the primary object. To run the aLSTM,
use it as you would the ``LSTM`` class:

```python
import torch
from torch.autograd import Variable
from alstm import aLSTM

seq_len, batch_size, hidden_size, adapt_size = 20, 5, 10, 3

alstm = aLSTM(hidden_size, hidden_size, adapt_size)

X = Variable(torch.rand(seq_len, batch_size, hidden_size))
out, hidden = alstm(X)
```

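The hidden state returned by ``aLSTM`` can be passed back in on a later call, for example when a long sequence is processed in segments (a minimal sketch continuing the snippet above; for truncated backpropagation through time you would additionally detach the state between segments):

```python
# process a long sequence in two segments, carrying the state across the split
X1 = Variable(torch.rand(seq_len, batch_size, hidden_size))
X2 = Variable(torch.rand(seq_len, batch_size, hidden_size))

out1, hidden = alstm(X1)          # hidden state starts at zero via init_hidden
out2, hidden = alstm(X2, hidden)  # continue from the state after the first segment
```
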
## Examples

To replicate the original experiments of the [aLSTM paper](https://arxiv.org/abs/1805.08574), head to
[https://github.com/flennerhag/adaptive_parameterization](https://github.com/flennerhag/adaptive_parameterization).

## Contributions

If you spot a bug, think the docs are useless, or have an idea for an extension, don't hesitate to send a PR!
If your contribution is substantial, please raise an issue first to check that your idea is in line with the
scope of this repo. Quick wins that would be great to have are:

- Support for bidirectional aLSTM
- Support for PyTorch's ``PackedSequence``

alstm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .alstm import alstm_cell, aLSTMCell, aLSTM

alstm/alstm.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
"""adaptive LSTM

PyTorch implementation of the adaptive LSTM (https://arxiv.org/abs/1805.08574).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd import Variable
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend

from .utils import Project, VariationalDropout, chunk, convert

# pylint: disable=too-many-locals,too-many-arguments,redefined-builtin


def alstm_cell(input, hidden, adapt, weights, bias=None):
    """The adaptive LSTM Cell for one time step."""
    hx, cx = hidden

    hidden_size, input_size = hx.size(1), input.size(1)
    chunks = [input_size + hidden_size, 8 * hidden_size]
    if bias is not None:
        chunks.append(4 * hidden_size)

    # split the adaptation (policy) vector into per-component scaling factors
    adapt = chunk(adapt, chunks, 1)

    input = torch.cat([input, hx], 1) * adapt.pop(0)
    gates = F.linear(input, weights) * adapt.pop(0)

    igates, hgates = gates.chunk(2, 1)
    if bias is not None:
        hgates = hgates + bias * adapt.pop(0)

    if input.is_cuda:
        state = fusedBackend.LSTMFused.apply
        return state(igates, hgates, cx)

    gates = igates + hgates
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)

    return hy, cy


class aLSTMCell(nn.modules.rnn.RNNCellBase):

    """Adaptive LSTM Cell"""

    def __init__(self, input_size, hidden_size, use_bias=True):
        super(aLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weights = Parameter(torch.Tensor(8 * hidden_size, hidden_size + input_size))
        if use_bias:
            self.bias = Parameter(torch.Tensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialization of parameters"""
        nn.init.orthogonal(self.weights)
        if self.use_bias:
            self.bias.data.zero_()
            # Forget gate bias initialization
            self.bias.data[self.hidden_size:2 * self.hidden_size] += 1

    def forward(self, input, hx, adapt):
        """Run aLSTM for one time step with given input and policy"""
        return alstm_cell(input, hx, adapt, self.weights, self.bias)


class aLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, adapt_size, output_size=None,
                 nlayers=1, dropout_hidden=None, dropout_adapt=None,
                 batch_first=False, bias=True):
        super(aLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.adapt_size = adapt_size
        self.output_size = output_size if output_size else hidden_size
        self.nlayers = nlayers
        self.dropout_hidden = dropout_hidden
        self.dropout_adapt = dropout_adapt
        self.batch_first = batch_first
        self.bias = bias

        psz, alyr, elyr, flyr = [], [], [], []
        for l in range(nlayers):
            # layer sizes: the first layer reads the input, the last layer
            # emits output_size, and intermediate layers are hidden-to-hidden
            if l == 0:
                ninp, nhid = input_size, hidden_size
            else:
                ninp, nhid = hidden_size, hidden_size

            if l == nlayers - 1:
                ninp, nhid = hidden_size, self.output_size
                if nlayers == 1:
                    ninp, nhid = input_size, self.output_size

            # policy latent variable
            ain = adapt_size + ninp + nhid if nlayers != 1 else ninp + nhid
            alyr.append(nn.LSTMCell(ain, adapt_size))

            # sub-policy projection
            ipsz = ninp + nhid
            opsz = 8 * nhid if not bias else 12 * nhid
            psz.append(ipsz + opsz)
            elyr.append(Project(adapt_size, psz[-1]))

            # aLSTM
            flyr.append(aLSTMCell(ninp, nhid, use_bias=bias))

        self.adapt_layers = nn.ModuleList(alyr)
        self.project_layers = nn.ModuleList(elyr)
        self.alstm_layers = nn.ModuleList(flyr)
        self.policy_sizes = psz

    def forward(self, input, hidden=None):
        """Run the aLSTM over a batch of sequences."""
        if self.batch_first:
            input = input.transpose(0, 1)

        if hidden is None:
            hidden = self.init_hidden(input.size(1))

        hidden = convert(hidden, list)

        adaptive_hidden, alstm_hidden = hidden

        dropout = False
        if self.training and (self.dropout_hidden or self.dropout_adapt):
            dropout = True
            lsz = [h[0].size() for h in alstm_hidden]
            asz = [h[0].size() for h in adaptive_hidden]
            dropout_alstm = VariationalDropout(
                input.data, self.dropout_hidden, lsz)
            dropout_adaptive = VariationalDropout(
                input.data, self.dropout_adapt, asz)

        output = []
        for x in input:
            for l in range(self.nlayers):
                alyr = self.adapt_layers[l]
                elyr = self.project_layers[l]
                flyr = self.alstm_layers[l]
                ahx, ahc = adaptive_hidden[l]
                fhx, fhc = alstm_hidden[l]

                # adaptation (policy) input: current input, previous aLSTM state
                # and, with several layers, the previous layer's policy state
                if self.nlayers != 1:
                    ax = torch.cat([x, fhx, adaptive_hidden[l-1][0]], 1)
                else:
                    ax = torch.cat([x, fhx], 1)

                ahx, ahc = alyr(ax, (ahx, ahc))

                if dropout:
                    ahx = dropout_adaptive(ahx, l)
                ax = ahx

                # project the policy state and run the aLSTM cell for this step
                ahe = elyr(ax)
                fhx, fhc = flyr(x, (fhx, fhc), ahe)

                if l == self.nlayers - 1:
                    output.append(fhx)

                if dropout:
                    fhx = dropout_alstm(fhx, l)

                adaptive_hidden[l] = [ahx, ahc]
                alstm_hidden[l] = [fhx, fhc]

                x = fhx

        hidden = (adaptive_hidden, alstm_hidden)
        output = torch.stack(output, 1 if self.batch_first else 0)

        hidden = convert(hidden, tuple)
        return output, hidden

    def init_hidden(self, bsz):
        """Utility for initializing hidden states (to zero)"""
        asz = self.adapt_size
        osz = self.output_size
        hsz = self.hidden_size
        weight = next(self.parameters()).data

        def hidden(out):
            return Variable(weight.new(bsz, out).zero_())

        ah = [(hidden(asz), hidden(asz)) for _ in range(self.nlayers)]
        fh = [(hidden(hsz if l != self.nlayers - 1 else osz),
               hidden(hsz if l != self.nlayers - 1 else osz))
              for l in range(self.nlayers)]
        return ah, fh
