Commit 892167a

First commit
1 parent d6a05f7 commit 892167a

28 files changed: 1960 additions, 0 deletions

config.ini

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
[Data]
bert_model = 'bert-base-cased'

[Network]
n_embed = 100
n_char_embed = 50
n_feat_embed = 100
n_bert_layers = 4
embed_dropout = .33
n_lstm_hidden = 400
n_lstm_layers = 3
lstm_dropout = .33
n_mlp_arc = 500
n_mlp_rel = 100
mlp_dropout = .33

[Optimizer]
lr = 2e-3
mu = .9
nu = .9
epsilon = 1e-12
clip = 5.0
decay = .75
decay_steps = 5000

[Run]
batch_size = 5000
epochs = 50000
patience = 100
min_freq = 2
fix_len = 20
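
The values above are plain INI strings. A minimal sketch of loading them with the standard-library configparser follows; the repository's own config loader is among the 28 changed files but not shown in this excerpt, so the ast.literal_eval coercion of quoted and numeric values is an assumption:

# A minimal sketch, assuming plain INI parsing with the standard library;
# the repo's actual config class (not shown in this commit excerpt) may
# post-process values differently.
import ast
from configparser import ConfigParser

config = ConfigParser()
config.read('config.ini')

def parse(value):
    # Turn '2e-3' into a float and "'bert-base-cased'" into a plain string,
    # falling back to the raw text for anything literal_eval rejects.
    try:
        return ast.literal_eval(value)
    except (SyntaxError, ValueError):
        return value

args = {key: parse(value)
        for section in config.sections()
        for key, value in config.items(section)}
print(args['lr'], args['n_lstm_hidden'])  # 0.002 400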

parser/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-

from .model import Model

__all__ = ['Model']

parser/cmds/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from .evaluate import Evaluate
from .predict import Predict
from .train import Train

__all__ = ['Evaluate', 'Predict', 'Train']
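
These three commands are registered by a top-level entry script that is among the 28 changed files but not shown here. A hypothetical sketch of that wiring, using only the add_subparser hooks each command defines (see evaluate.py and predict.py below):

# Hypothetical wiring only: the actual entry script is not part of this
# excerpt, and CMD.__call__ expects an args object with an update() method
# (a config-like wrapper rather than a bare argparse.Namespace).
import argparse
from parser.cmds import Evaluate, Predict, Train

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(title='commands', dest='mode')
subcommands = {'evaluate': Evaluate(), 'predict': Predict(), 'train': Train()}
for name, subcommand in subcommands.items():
    subcommand.add_subparser(name, subparsers)
args = parser.parse_args()
subcommands[args.mode](args)  # dispatch to the chosen command's __call__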

parser/cmds/cmd.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-

import os
from parser.utils import Embedding
from parser.utils.alg import crf, eisner
from parser.utils.common import bos, pad, unk
from parser.utils.corpus import CoNLL, Corpus
from parser.utils.field import BertField, CharField, Field
from parser.utils.fn import ispunct, istree, numericalize_arcs
from parser.utils.metric import Metric

import torch
import torch.nn as nn
from transformers import BertTokenizer


class CMD(object):

    def __call__(self, args):
        self.args = args
        if not os.path.exists(args.file):
            os.mkdir(args.file)
        if not os.path.exists(args.fields) or args.preprocess:
            print("Preprocess the data")
            self.WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
            if args.feat == 'char':
                self.FEAT = CharField('chars', pad=pad, unk=unk, bos=bos,
                                      fix_len=args.fix_len, tokenize=list)
            elif args.feat == 'bert':
                tokenizer = BertTokenizer.from_pretrained(args.bert_model)
                self.FEAT = BertField('bert', pad='[PAD]', bos='[CLS]',
                                      tokenize=tokenizer.encode)
            else:
                self.FEAT = Field('tags', bos=bos)
            self.ARC = Field('arcs', bos=bos, use_vocab=False,
                             fn=numericalize_arcs)
            self.REL = Field('rels', bos=bos)
            if args.feat in ('char', 'bert'):
                self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                    HEAD=self.ARC, DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=self.WORD, CPOS=self.FEAT,
                                    HEAD=self.ARC, DEPREL=self.REL)

            train = Corpus.load(args.ftrain, self.fields,
                                args.max_len, args.proj, args.parts)
            if args.fembed:
                embed = Embedding.load(args.fembed, args.unk)
            else:
                embed = None
            self.WORD.build(train, args.min_freq, embed)
            self.FEAT.build(train)
            self.REL.build(train)
            torch.save(self.fields, args.fields)
        else:
            self.fields = torch.load(args.fields)
            if args.feat in ('char', 'bert'):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.ARC, self.REL = self.fields.HEAD, self.fields.DEPREL
        self.puncts = torch.tensor([i for s, i in self.WORD.vocab.stoi.items()
                                    if ispunct(s)]).to(args.device)
        self.criterion = nn.CrossEntropyLoss(reduction='sum')

        print(f"{self.WORD}\n{self.FEAT}\n{self.ARC}\n{self.REL}")
        args.update({
            'n_words': self.WORD.vocab.n_init,
            'n_feats': len(self.FEAT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index
        })

    def train(self, loader):
        self.model.train()

        total_loss, metric = 0, Metric()

        for words, feats, arcs, rels in loader:
            self.optimizer.zero_grad()

            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            arc_scores, rel_scores = self.model(words, feats)
            loss, arc_scores = self.get_loss(arc_scores, rel_scores,
                                             arcs, rels, mask)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.args.clip)
            self.optimizer.step()
            self.scheduler.step()

            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
            if self.args.partial:
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            total_loss += loss.item()
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss /= len(loader)

        return total_loss, metric

    @torch.no_grad()
    def evaluate(self, loader):
        self.model.eval()

        total_loss, metric = 0, Metric()

        for words, feats, arcs, rels in loader:
            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            arc_scores, rel_scores = self.model(words, feats)
            loss, arc_scores = self.get_loss(arc_scores, rel_scores,
                                             arcs, rels, mask)
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
            if self.args.partial:
                mask &= arcs.ge(0)
            # ignore all punctuation if not specified
            if not self.args.punct:
                mask &= words.unsqueeze(-1).ne(self.puncts).all(-1)
            total_loss += loss.item()
            metric(arc_preds, rel_preds, arcs, rels, mask)
        total_loss /= len(loader)

        return total_loss, metric

    @torch.no_grad()
    def predict(self, loader):
        self.model.eval()

        all_arcs, all_rels, all_probs = [], [], []
        for words, feats in loader:
            mask = words.ne(self.args.pad_index)
            # ignore the first token of each sentence
            mask[:, 0] = 0
            lens = mask.sum(1).tolist()
            arc_scores, rel_scores = self.model(words, feats)
            if self.args.marg:
                # replace the raw scores with CRF marginal probabilities
                arc_scores = crf(arc_scores, mask)
            arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
            all_arcs.extend(arc_preds[mask].split(lens))
            all_rels.extend(rel_preds[mask].split(lens))
            if self.args.prob:
                arc_probs = arc_scores.gather(-1, arc_preds.unsqueeze(-1))
                all_probs.extend(arc_probs.squeeze(-1)[mask].split(lens))
        all_arcs = [seq.tolist() for seq in all_arcs]
        all_rels = [self.REL.vocab.id2token(seq.tolist()) for seq in all_rels]
        all_probs = [[round(p, 4) for p in seq.tolist()] for seq in all_probs]

        return all_arcs, all_rels, all_probs

    def get_loss(self, arc_scores, rel_scores, arcs, rels, mask):
        total = mask.sum()
        batch_size, seq_len = mask.shape
        # arc loss and arc probabilities from the tree CRF
        # (see parser.utils.alg.crf)
        arc_loss, arc_probs = crf(arc_scores, mask, arcs,
                                  self.args.partial)
        if self.args.partial:
            mask = mask & arcs.ge(0)
        # select the rel scores at the gold heads
        rel_scores, rels = rel_scores[mask], rels[mask]
        rel_scores = rel_scores[torch.arange(len(rels)), arcs[mask]]
        rel_loss = self.criterion(rel_scores, rels)
        loss = (arc_loss + rel_loss) / total
        return loss, arc_probs

    def decode(self, arc_scores, rel_scores, mask):
        lens = mask.sum(1)
        # prevent self-loops
        arc_scores.diagonal(0, 1, 2).fill_(float('-inf'))
        arc_preds = arc_scores.argmax(-1)
        # find sentences whose greedy arc predictions are not well-formed trees
        bad = [not istree(sequence[:l+1], self.args.proj)
               for l, sequence in zip(lens.tolist(), arc_preds.tolist())]
        if self.args.tree and any(bad):
            # re-decode them with the Eisner algorithm to guarantee trees
            arc_preds[bad] = eisner(arc_scores[bad], mask[bad])
        rel_preds = rel_scores.argmax(-1)
        rel_preds = rel_preds.gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)

        return arc_preds, rel_preds
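
Metric is imported from parser.utils.metric, one of the changed files not shown in this excerpt. As a rough, self-contained sketch of what a call like metric(arc_preds, rel_preds, arcs, rels, mask) computes, assuming it tracks the usual attachment scores:

import torch

def attachment_scores(arc_preds, rel_preds, arcs, rels, mask):
    # Illustrative only; the repo's Metric class is not part of this excerpt.
    # UAS counts tokens whose predicted head is correct; LAS additionally
    # requires the predicted dependency relation to match. Both ignore
    # positions masked out above (padding, BOS, optionally punctuation).
    arc_hits = arc_preds.eq(arcs)[mask]
    rel_hits = arc_hits & rel_preds.eq(rels)[mask]
    total = mask.sum().item()
    return arc_hits.sum().item() / total, rel_hits.sum().item() / total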

parser/cmds/evaluate.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-

from datetime import datetime
from parser import Model
from parser.cmds.cmd import CMD
from parser.utils.corpus import Corpus
from parser.utils.data import TextDataset, batchify


class Evaluate(CMD):

    def add_subparser(self, name, parser):
        subparser = parser.add_parser(
            name, help='Evaluate the specified model and dataset.'
        )
        subparser.add_argument('--punct', action='store_true',
                               help='whether to include punctuation')
        subparser.add_argument('--proj', action='store_true',
                               help='whether to projectivise the outputs')
        subparser.add_argument('--fdata', default='data/ptb/test.conllx',
                               help='path to dataset')

        return subparser

    def __call__(self, args):
        super(Evaluate, self).__call__(args)

        print("Load the dataset")
        corpus = Corpus.load(args.fdata, self.fields)
        dataset = TextDataset(corpus, self.fields, args.buckets)
        # set the data loader
        dataset.loader = batchify(dataset, args.batch_size)
        print(f"{len(dataset)} sentences, "
              f"{len(dataset.loader)} batches, "
              f"{len(dataset.buckets)} buckets")

        print("Load the model")
        self.model = Model.load(args.model)
        print(f"{self.model}\n")

        print("Evaluate the dataset")
        start = datetime.now()
        loss, metric = self.evaluate(dataset.loader)
        total_time = datetime.now() - start
        print(f"Loss: {loss:.4f} {metric}")
        print(f"{total_time}s elapsed, "
              f"{len(dataset) / total_time.total_seconds():.2f} Sents/s")

parser/cmds/predict.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-

from datetime import datetime
from parser import Model
from parser.cmds.cmd import CMD
from parser.utils.corpus import Corpus
from parser.utils.data import TextDataset, batchify
from parser.utils.field import Field

import torch


class Predict(CMD):

    def add_subparser(self, name, parser):
        subparser = parser.add_parser(
            name, help='Use a trained model to make predictions.'
        )
        subparser.add_argument('--prob', action='store_true',
                               help='whether to output probs')
        subparser.add_argument('--marg', action='store_true',
                               help='whether to use marginal probs')
        subparser.add_argument('--proj', action='store_true',
                               help='whether to projectivise the outputs')
        subparser.add_argument('--fdata', default='data/ptb/test.conllx',
                               help='path to dataset')
        subparser.add_argument('--fpred', default='pred.conllx',
                               help='path to predicted result')

        return subparser

    def __call__(self, args):
        super(Predict, self).__call__(args)

        print("Load the dataset")
        if args.prob:
            self.fields = self.fields._replace(PHEAD=Field('probs'))
        corpus = Corpus.load(args.fdata, self.fields)
        dataset = TextDataset(corpus, [self.WORD, self.FEAT], args.buckets)
        # set the data loader
        dataset.loader = batchify(dataset, args.batch_size)
        print(f"{len(dataset)} sentences, "
              f"{len(dataset.loader)} batches")

        print("Load the model")
        self.model = Model.load(args.model)
        print(f"{self.model}\n")

        print("Make predictions on the dataset")
        start = datetime.now()
        pred_arcs, pred_rels, pred_probs = self.predict(dataset.loader)
        total_time = datetime.now() - start
        # restore the order of sentences in the buckets
        indices = torch.tensor([i for bucket in dataset.buckets.values()
                                for i in bucket]).argsort()
        corpus.arcs = [pred_arcs[i] for i in indices]
        corpus.rels = [pred_rels[i] for i in indices]
        if args.prob:
            corpus.probs = [pred_probs[i] for i in indices]
        print(f"Save the predicted result to {args.fpred}")
        corpus.save(args.fpred)
        print(f"{total_time}s elapsed, "
              f"{len(dataset) / total_time.total_seconds():.2f} Sents/s")
