
Commit be25ba1 (0 parents)

Shivangi Mahto authored and committed

Multi-timescale LSTM LM ICLR

48 files changed: +15803 additions, 0 deletions

Bins_vs_location_of_abl_chunk.png (62.9 KB)

Bins_vs_timescale_PTB.png (62.9 KB)

Bins_vs_timescale_Wiki.png (65.6 KB)

Plot estimated timescale.ipynb

Lines changed: 346 additions & 0 deletions
Large diffs are not rendered by default.

ReadMe.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@

# Code for training a multi-timescale (MTS) language model.

## Required dependencies: Python 3.6 or above, NumPy, SciPy, and PyTorch 1.7.0 or above with CUDA version 10.1

## Example script to train and evaluate a standard and an MTS LM on the PTB dataset:

### bash run.sh

## Detailed description:

### 1. To download PTB/WIKI data: bash getdata.sh

### 2. model_mts.py defines the multi-timescale language model.

### 3. To train a multi-timescale model, use train_mts.py as follows:

#### On PTB data

python train_mts.py --batch_size 20 --data data/penn --dropouti 0.4 --dropouth 0.25 --seed 141 --epoch 1000 --save train_mts.pt

#### On Wiki data

python train_mts.py --data data/wikitext-2 --dropouth 0.2 --seed 1882 --epoch 1000 --save train_mts.pt

### 4. To evaluate the model on the test set, including different word-frequency bins and a bootstrap test set:

#### Trained LM on PTB data:

python model_evaluation.py --model_name train_mts.pt --data data/penn/

#### Trained LM on Wiki data:

python model_evaluation.py --model_name train_mts.pt --data data/wikitext-2/
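
Step 4 above reports perplexity broken out by word-frequency bins; the actual binning lives in model_evaluation.py, which is not rendered here. Purely as an illustration, and not necessarily the repo's exact scheme, such bins could be derived from the token Counter that data.py's Dictionary maintains (see that file below):

```python
# Illustrative sketch only: bin the vocabulary by corpus frequency.
# The bin count (5) and the use of data.Corpus are assumptions, not taken
# from model_evaluation.py. The Counter accumulates over train/valid/test
# as tokenized by data.py.
import numpy as np
from data import Corpus

corpus = Corpus('data/penn')
counts = np.array([corpus.dictionary.counter[i]
                   for i in range(len(corpus.dictionary))])

order = np.argsort(-counts)          # vocabulary ids, most frequent first
bins = np.array_split(order, 5)      # 5 roughly equal-sized frequency bins
for b, ids in enumerate(bins):
    print(f"bin {b}: {len(ids)} word types, {counts[ids].sum()} corpus tokens")
```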

data.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@

import os
import torch

from collections import Counter


class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []
        self.counter = Counter()
        self.total = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        token_id = self.word2idx[word]
        self.counter[token_id] += 1
        self.total += 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return ids
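
Each split returned by Corpus is a single 1-D LongTensor of token ids. A small usage sketch follows; the batchify helper mirrors the usual language-model convention of folding the stream into parallel columns, and is an assumption here rather than necessarily the exact helper train_mts.py uses:

```python
import torch
from data import Corpus

def batchify(data, bsz):
    # Keep only a whole number of batches, then reshape so that each of the
    # bsz columns is a contiguous slice of the original token stream.
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    return data.view(bsz, -1).t().contiguous()

corpus = Corpus('data/penn')              # expects train.txt / valid.txt / test.txt
print(len(corpus.dictionary))             # vocabulary size
train_data = batchify(corpus.train, 20)   # shape: (tokens_per_column, 20)
print(train_data.shape)
```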

embed_regularize.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@

import numpy as np

import torch

def embedded_dropout(embed, words, dropout=0.1, scale=None):
    if dropout:
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight
    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    padding_idx = embed.padding_idx
    if padding_idx is None:
        padding_idx = -1

    X = torch.nn.functional.embedding(words, masked_embed_weight,
                                      padding_idx, embed.max_norm, embed.norm_type,
                                      embed.scale_grad_by_freq, embed.sparse
                                      )
    return X

if __name__ == '__main__':
    V = 50
    h = 4
    bptt = 10
    batch_size = 2

    embed = torch.nn.Embedding(V, h)

    words = np.random.random_integers(low=0, high=V-1, size=(batch_size, bptt))
    words = torch.LongTensor(words)

    origX = embed(words)
    X = embedded_dropout(embed, words)

    print(origX)
    print(X)
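
The __main__ block above only prints the masked embeddings. As a rough sketch of how embedded_dropout would typically sit inside a model's forward pass (TinyLM, its layer sizes, and the 0.1 rate are hypothetical and are not this repo's model_mts.py):

```python
import torch
import torch.nn as nn
from embed_regularize import embedded_dropout

class TinyLM(nn.Module):
    def __init__(self, ntoken=50, ninp=8, nhid=16):
        super().__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, words):
        # Drop whole word embeddings (rows of the embedding matrix), but only
        # while training; at eval time the embeddings are used unmasked.
        emb = embedded_dropout(self.encoder, words,
                               dropout=0.1 if self.training else 0)
        output, _ = self.rnn(emb)
        return self.decoder(output)

model = TinyLM()
tokens = torch.randint(0, 50, (10, 2))    # (seq_len, batch) of token ids
print(model(tokens).shape)                # torch.Size([10, 2, 50])
```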
