diff --git a/sequence_labeling/README.md b/sequence_labeling/README.md
new file mode 100644
index 0000000000..401b5110cd
--- /dev/null
+++ b/sequence_labeling/README.md
@@ -0,0 +1,47 @@
+# Sequence Labeling with BiRNN and CRF
+
+This example trains a bidirectional RNN (LSTM or GRU) with a CRF layer for part-of-speech tagging. It uses the UDPOS dataset
+in `torchtext` and can be extended to other sequence labeling tasks by replacing the dataset accordingly.
+
+The `train.py` script accepts the following arguments:
+
+```bash
+optional arguments:
+  -h, --help            show this help message and exit
+  --emb_path EMB_PATH   path to pretrained embeddings
+  --optimizer {sgd,adam,nesterov}
+                        optimizer to use
+  --sgd_momentum SGD_MOMENTUM
+                        momentum for stochastic gradient descent
+  --rnn {lstm,gru}      RNN type
+  --rnn_dim RNN_DIM     RNN hidden state dimension
+  --word_emb_dim WORD_EMB_DIM
+                        word embedding dimension
+  --char_emb_dim CHAR_EMB_DIM
+                        char embedding dimension
+  --char_rnn_dim CHAR_RNN_DIM
+                        char RNN hidden state dimension
+  --word_min_freq WORD_MIN_FREQ
+                        minimum frequency threshold in word vocabulary
+  --train_batch_size TRAIN_BATCH_SIZE
+                        batch size for training phase
+  --val_batch_size VAL_BATCH_SIZE
+                        batch size for evaluation phase
+  --early_stopping_patience EARLY_STOPPING_PATIENCE
+                        early stopping patience
+  --dropout_before_rnn DROPOUT_BEFORE_RNN
+                        dropout rate on RNN inputs
+  --dropout_after_rnn DROPOUT_AFTER_RNN
+                        dropout rate on RNN outputs
+  --lr LR               starting learning rate
+  --lr_decay LR_DECAY   learning rate decay factor
+  --min_lr MIN_LR       minimum learning rate
+  --lr_shrink LR_SHRINK
+                        factor by which to reduce the learning rate on plateau
+  --lr_shrink_patience LR_SHRINK_PATIENCE
+                        epochs without improvement before reducing the learning rate
+  --crf {none,small,large}
+                        CRF type, or none for no CRF
+  --max_epochs MAX_EPOCHS
+                        maximum training epochs
+```
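For reference, a typical run with the defaults from `args.py` might look like the command below; the flag values are illustrative, not tuned settings. The script expects pretrained embeddings such as `glove.6B.100d.txt` at `--emb_path` (default `.data/glove.6B.100d.txt`).

```bash
python train.py --rnn lstm --crf small --optimizer adam --lr 0.01 --train_batch_size 16
```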
diff --git a/sequence_labeling/args.py b/sequence_labeling/args.py
new file mode 100644
index 0000000000..f5c50f43ea
--- /dev/null
+++ b/sequence_labeling/args.py
@@ -0,0 +1,30 @@
+import argparse
+import torch
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--emb_path', default='.data/glove.6B.100d.txt', type=str, help='path to pretrained embeddings')
+parser.add_argument('--optimizer', default='adam', type=str, choices=['sgd', 'adam', 'nesterov'],
+                    help='optimizer to use')
+parser.add_argument('--sgd_momentum', default=0.9, type=float, help='momentum for stochastic gradient descent')
+parser.add_argument('--rnn', default='lstm', type=str, choices=['lstm', 'gru'], help='RNN type')
+parser.add_argument('--rnn_dim', default=100, type=int, help='RNN hidden state dimension')
+parser.add_argument('--word_emb_dim', default=100, type=int, help='word embedding dimension')
+parser.add_argument('--char_emb_dim', default=50, type=int, help='char embedding dimension')
+parser.add_argument('--char_rnn_dim', default=50, type=int, help='char RNN hidden state dimension')
+parser.add_argument('--word_min_freq', default=2, type=int, help='minimum frequency threshold in word vocabulary')
+parser.add_argument('--train_batch_size', default=16, type=int, help='batch size for training phase')
+parser.add_argument('--val_batch_size', default=64, type=int, help='batch size for evaluation phase')
+parser.add_argument('--early_stopping_patience', default=30, type=int, help='early stopping patience')
+parser.add_argument('--dropout_before_rnn', default=0.5, type=float, help='dropout rate on RNN inputs')
+parser.add_argument('--dropout_after_rnn', default=0.5, type=float, help='dropout rate on RNN outputs')
+parser.add_argument('--lr', default=0.01, type=float, help='starting learning rate')
+parser.add_argument('--lr_decay', default=0.001, type=float, help='learning rate decay factor')
+parser.add_argument('--min_lr', default=0.0005, type=float, help='minimum learning rate')
+parser.add_argument('--lr_shrink', default=0.5, type=float, help='factor by which to reduce the learning rate on plateau')
+parser.add_argument('--lr_shrink_patience', default=0, type=int, help='epochs without improvement before reducing the learning rate')
+parser.add_argument('--crf', default='small', type=str, choices=['none', 'small', 'large'], help='CRF type, or none for no CRF')
+parser.add_argument('--max_epochs', default=300, type=int, help='maximum training epochs')
+
+args = parser.parse_args()
+
+args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
diff --git a/sequence_labeling/model.py b/sequence_labeling/model.py
new file mode 100644
index 0000000000..3c036f5981
--- /dev/null
+++ b/sequence_labeling/model.py
@@ -0,0 +1,282 @@
+import codecs
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def get_rnn(input_dim, hidden_dim, rnn_type):
+    if rnn_type == 'gru':
+        return nn.GRU(input_dim, hidden_dim, bidirectional=True)
+    elif rnn_type == 'lstm':
+        return nn.LSTM(input_dim, hidden_dim, bidirectional=True)
+    else:
+        raise ValueError('Unknown RNN type: %s' % rnn_type)
+
+
+class SequenceLabelingModel(nn.Module):
+    def __init__(self, args, logger):
+        super(SequenceLabelingModel, self).__init__()
+
+        self.args = args
+        self.logger = logger
+
+        self.word_emb = WordEmbedding(args, logger)
+        self.word_emb.lut.weight.requires_grad = False  # keep the pretrained word embeddings frozen
+
+        if args.char_rnn_dim > 0 and args.char_emb_dim > 0:
+            self.char_emb = nn.Embedding(len(args.char2idx), args.char_emb_dim)
+            self.char_rnn = get_rnn(args.char_emb_dim, args.char_rnn_dim, args.rnn)
+            self.rnn = get_rnn(args.word_emb_dim + args.char_rnn_dim * 2, args.rnn_dim, args.rnn)
+        else:
+            self.rnn = get_rnn(args.word_emb_dim, args.rnn_dim, args.rnn)
+            self.char_emb = None
+            self.char_rnn = None
+
+        self.dropout_before_rnn = nn.Dropout(args.dropout_before_rnn)
+        self.dropout_after_rnn = nn.Dropout(args.dropout_after_rnn)
+
+        if args.crf in {'small', 'large'}:
+            self.loss = CRF(args.rnn_dim * 2, len(args.tag2idx), args.tag2idx[args.tag_bos], args.tag2idx[args.tag_eos],
+                            args.tag2idx[args.tag_pad], large_model=args.crf == 'large')
+        else:
+            self.loss = CELoss(args.rnn_dim * 2, len(args.tag2idx))
+
+    def _get_char_emb(self, chars):
+        """
+        :param chars: batch_size x seq_len x word_len
+        :return: seq_len x batch_size x char_rnn_dim * 2
+        """
+        chars, chars_lens = chars[0], chars[2]
+
+        batch_size, seq_len, word_len = chars.size()
+        n = seq_len * batch_size
+        chars = chars.permute(2, 1, 0).contiguous().view(word_len, n)
+        chars_lens = chars_lens.t().contiguous().view(1, n).expand(word_len, n)
+        # mask out character positions beyond each word's length before max-pooling
+        mask = torch.arange(word_len, device=chars.device).view(word_len, 1).expand(word_len, n) < chars_lens
+
+        char_embeds = self.char_rnn(self.char_emb(chars))[0]
+        mask = mask.view(word_len, n, 1).expand_as(char_embeds)
+        return (char_embeds * mask.float()).max(0)[0].view(seq_len, batch_size, -1)
+
+    def _get_rnn_features(self, rnn, x, lengths):
+        """
+        :param x: seq_len x batch_size x emb_dim
+        :param lengths: batch_size
+        :return: seq_len x batch_size x hidden_dim * 2
+        """
+        lengths, idx_sort = lengths.sort(descending=True)
+        _, idx_unsort = idx_sort.sort(descending=False)
+        emb = x.index_select(1, idx_sort)
+        emb_packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.tolist())
+        out = rnn(emb_packed)[0]
+        out = nn.utils.rnn.pad_packed_sequence(out)[0]
+
+        return out.index_select(1, idx_unsort)
+
+    def _get_features(self, batch):
+        """
+        batch.word[0]: seq_len x batch_size
+        batch.word[1]: batch_size
+        batch.char[0]: batch_size x seq_len x word_len
+        :return: seq_len x batch_size x rnn_dim * 2
+        """
+        words, seq_lens = batch.word[0], batch.word[1]
+        embeds = self.word_emb.forward(words)
+
+        if self.args.char_rnn_dim > 0 and self.args.char_emb_dim > 0:
+            char_embeds = self._get_char_emb(batch.char)
+            embeds = torch.cat([embeds, char_embeds], 2)
+
+        embeds = self.dropout_before_rnn(embeds)
+        features = self._get_rnn_features(self.rnn, embeds, seq_lens)
+        return self.dropout_after_rnn(features)
+
+    def forward(self, batch):
+        """
+        batch.word[0]: seq_len x batch_size
+        batch.word[1]: batch_size
+        batch.char[0]: batch_size x seq_len x word_len
+        :return: scalar tensor
+        """
+        seq_lens = batch.word[1]
+        features = self._get_features(batch)
+        return self.loss.forward(features, batch.tag, seq_lens)
+
+    def decode(self, batch):
+        """
+        batch.word[0]: seq_len x batch_size
+        batch.word[1]: batch_size
+        batch.char[0]: batch_size x seq_len x word_len
+        :return: seq_len x batch_size
+        """
+        seq_lens = batch.word[1]
+        features = self._get_features(batch)
+        return self.loss.decode(features, seq_lens)
+
+
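A minimal, self-contained sketch of the length masking and max-pooling used in `_get_char_emb`; the shapes and values here are made up purely for illustration:

```python
import torch

# Toy shapes: 2 flattened words, max word length 4, char feature dim 3.
word_len, n, dim = 4, 2, 3
char_outputs = torch.randn(word_len, n, dim)    # stand-in for char-RNN outputs (time-major)
chars_lens = torch.tensor([[2, 4]] * word_len)  # word 0 has 2 chars, word 1 has 4

# Position index along the time axis, compared against each word's length.
mask = torch.arange(word_len).view(word_len, 1).expand(word_len, n) < chars_lens
mask = mask.view(word_len, n, 1).expand_as(char_outputs)

# Padded positions are zeroed before the max-pool, mirroring the model code.
pooled = (char_outputs * mask.float()).max(0)[0]
print(pooled.shape)  # torch.Size([2, 3])
```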
+class WordEmbedding(nn.Module):
+    def __init__(self, args, logger):
+        super(WordEmbedding, self).__init__()
+
+        self.lut = nn.Embedding(len(args.word2idx), args.word_emb_dim)
+        self.lut.weight.data.uniform_(-0.1, 0.1)
+
+        logger.info('Loading word embeds')
+        word_embeds = {}
+        for line in codecs.open(args.emb_path, 'r', 'utf-8'):
+            line = line.strip().split()
+            if len(line) != args.word_emb_dim + 1:
+                continue
+            word_embeds[line[0]] = torch.Tensor([float(i) for i in line[1:]])
+
+        logger.info('Matching word embeds')
+        count_raw, count_lower = 0, 0
+        for word, idx in args.word2idx.items():
+            if word in word_embeds:
+                self.lut.weight.data[idx].copy_(word_embeds[word])
+                count_raw += 1
+            elif word.lower() in word_embeds:
+                self.lut.weight.data[idx].copy_(word_embeds[word.lower()])
+                count_lower += 1
+
+        logger.info('Coverage %.4f (%d+%d/%d, raw+lower)' % (float(count_raw + count_lower) / len(args.word2idx),
+                                                             count_raw, count_lower, len(args.word2idx)))
+
+    def forward(self, words):
+        """
+        :param words: seq_len x batch_size
+        :return: seq_len x batch_size x emb_dim
+        """
+        return self.lut.forward(words)
+
+
+class CRF(nn.Module):
+    def __init__(self, feature_dim, tags_num, bos_ix, eos_ix, pad_ix, large_model=False):
+        super(CRF, self).__init__()
+        self.tags_num = tags_num
+        self.large_model = large_model
+        self.bos_ix, self.eos_ix, self.pad_ix = bos_ix, eos_ix, pad_ix
+
+        if self.large_model:
+            self.feat2tag = nn.Linear(feature_dim, self.tags_num * self.tags_num)
+        else:
+            self.feat2tag = nn.Linear(feature_dim, self.tags_num)
+            self.transitions = nn.Parameter(torch.zeros(self.tags_num, self.tags_num))
+
+    def _get_crf_scores(self, features):
+        """
+        :param features: seq_len x batch_size x feature_dim
+        :return: seq_len x batch_size x tags_num x tags_num
+        """
+        s_len, b_size, n_tags = features.size(0), features.size(1), self.tags_num
+        if self.large_model:
+            return self.feat2tag(features).view(s_len, b_size, n_tags, n_tags)
+        else:
+            emit_scores = self.feat2tag(features).view(s_len, b_size, n_tags, 1).expand(s_len, b_size, n_tags, n_tags)
+            transition_scores = self.transitions.view(1, 1, n_tags, n_tags).expand(s_len, b_size, n_tags, n_tags)
+            return emit_scores + transition_scores
+
+    def _get_gold_scores(self, crf_scores, tags, seq_lens):
+        """
+        :param crf_scores: seq_len x batch_size x tags_num x tags_num
+        :param tags: seq_len x batch_size
+        :param seq_lens: batch_size
+        :return: scalar tensor
+        """
+        s_len, b_size = crf_scores.size(0), crf_scores.size(1)
+        pad_tags = torch.full((1, b_size), self.pad_ix, dtype=torch.long, device=tags.device)
+        bigram_tags = self.tags_num * tags + torch.cat([tags[1:, :], pad_tags], 0)
+
+        gold_score = crf_scores.view(s_len, b_size, -1).gather(2, bigram_tags.view(s_len, b_size, 1)).squeeze(2)
+        gold_score = gold_score.cumsum(0).gather(0, seq_lens.view(1, b_size) - 1).sum()
+        return gold_score
+
+    def forward(self, features, tags, seq_lens):
+        """
+        :param features: seq_len x batch_size x feature_dim
+        :param tags: seq_len x batch_size
+        :param seq_lens: batch_size
+        :return: scalar tensor
+        """
+        s_len, b_size, n_tags = features.size(0), features.size(1), self.tags_num
+        crf_scores = self._get_crf_scores(features)
+        gold_scores = self._get_gold_scores(crf_scores, tags, seq_lens)
+
+        cur_score = crf_scores[0, :, self.bos_ix, :].contiguous()
+        for idx in range(1, s_len):
+            next_score = cur_score.view(b_size, n_tags, 1).expand(b_size, n_tags, n_tags) + crf_scores[idx, :, :, :]
+            next_score = self.log_sum_exp(next_score)
+            mask = (idx < seq_lens).view(b_size, 1).expand_as(next_score).float()
+            cur_score = mask * next_score + (1 - mask) * cur_score
+        cur_score = cur_score[:, self.eos_ix].sum()
+
+        return cur_score - gold_scores
+
+    def decode(self, features, seq_lens):
+        """
+        :param features: seq_len x batch_size x feature_dim
+        :param seq_lens: batch_size
+        :return: seq_len x batch_size
+        """
+        s_len, b_size, n_tags = features.size(0), features.size(1), self.tags_num
+        crf_scores = self._get_crf_scores(features)
+
+        cur_score = crf_scores[0, :, self.bos_ix, :].contiguous()
+        back_pointers = []
+        for idx in range(1, s_len):
+            next_score = cur_score.view(b_size, n_tags, 1).expand(b_size, n_tags, n_tags) + crf_scores[idx, :, :, :]
+            cur_score, cur_ptr = next_score.max(1)
+            cur_ptr.masked_fill_((idx >= seq_lens).view(b_size, 1).expand_as(cur_ptr), self.eos_ix)
+            back_pointers.append(cur_ptr)
+
+        best_seq = torch.full((s_len, b_size), self.eos_ix, dtype=torch.long, device=features.device)
+        best_seq[0, :] = self.bos_ix
+        ptr = torch.full((b_size, 1), self.eos_ix, dtype=torch.long, device=features.device)
+        for idx in range(s_len - 2, -1, -1):
+            ptr = back_pointers[idx].gather(1, ptr)
+            best_seq[idx + 1, :] = ptr.squeeze(1)
+        return best_seq
+
+    @staticmethod
+    def log_sum_exp(v):
+        """
+        :param v: (p, q, r) tensor
+        :return: (p, r) tensor
+        """
+        p, r = v.size(0), v.size(2)
+        max_v = v.max(1)[0]
+        return max_v + (v - max_v.view(p, 1, r).expand_as(v)).exp().sum(1).log()
+
+
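`CRF.log_sum_exp` uses the standard max-shift trick to reduce over dim 1; on recent PyTorch it should agree with the built-in `torch.logsumexp`, which a quick toy check like the following (not part of the example) can confirm:

```python
import torch

v = torch.randn(5, 7, 3)  # (p, q, r), arbitrary toy shape

# Max-shifted reduction over dim 1, as in CRF.log_sum_exp.
max_v = v.max(1)[0]
manual = max_v + (v - max_v.view(5, 1, 3).expand_as(v)).exp().sum(1).log()

builtin = torch.logsumexp(v, dim=1)
print(torch.allclose(manual, builtin, atol=1e-6))  # expected: True
```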
+class CELoss(nn.Module):
+    def __init__(self, feature_dim, tags_num):
+        super(CELoss, self).__init__()
+        self.hidden2tag = nn.Linear(feature_dim, tags_num)
+        self.tags_num = tags_num
+
+    def _get_scores(self, features):
+        """
+        :param features: seq_len x batch_size x feature_dim
+        :return: seq_len x batch_size x tags_num
+        """
+        return F.log_softmax(self.hidden2tag(features), dim=2)
+
+    def forward(self, features, tags, seq_lens):
+        """
+        :param features: seq_len x batch_size x feature_dim
+        :param tags: seq_len x batch_size
+        :param seq_lens: batch_size
+        :return: scalar tensor
+        """
+        s_len, b_size = features.size(0), features.size(1)
+        loss = self._get_scores(features).gather(2, tags.view(s_len, b_size, 1)).squeeze(2)
+        loss = loss.cumsum(0).gather(0, seq_lens.view(1, b_size) - 1).sum()
+        return -loss
+
+    def decode(self, features, seq_lens):
+        """
+        :param features: seq_len x batch_size x feature_dim
+        :param seq_lens: batch_size
+        :return: seq_len x batch_size
+        """
+        return self._get_scores(features).max(2)[1]
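Both `CRF._get_gold_scores` and `CELoss.forward` rely on the same cumsum-then-gather trick to sum per-position scores only up to each sequence's length. A standalone sanity check with toy numbers (illustrative only):

```python
import torch

scores = torch.tensor([[1.0, 10.0],
                       [2.0, 20.0],
                       [4.0, 40.0]])      # seq_len x batch_size per-position scores
seq_lens = torch.tensor([2, 3])           # valid length of each column

# Cumulative sum over time, then read off the prefix sum at position (length - 1).
via_cumsum = scores.cumsum(0).gather(0, seq_lens.view(1, -1) - 1).sum()

# Equivalent masked sum over the valid positions.
mask = torch.arange(scores.size(0)).view(-1, 1) < seq_lens
via_mask = (scores * mask.float()).sum()

print(via_cumsum.item(), via_mask.item())  # both 3 + 70 = 73.0
```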
diff --git a/sequence_labeling/train.py b/sequence_labeling/train.py
new file mode 100644
index 0000000000..19af3129e0
--- /dev/null
+++ b/sequence_labeling/train.py
@@ -0,0 +1,130 @@
+import sys
+import logging
+
+import torch
+from torch.optim import lr_scheduler
+from torchtext import datasets, data
+from tqdm import tqdm
+
+from model import SequenceLabelingModel
+from args import args
+
+logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+def get_data_iter():
+    WORD = data.Field(init_token='<bos>', eos_token='<eos>', include_lengths=True)
+    UD_TAG = data.Field(init_token='<bos>', eos_token='<eos>')
+    PTB_TAG = data.Field(init_token='<bos>', eos_token='<eos>')
+    CHAR_NESTING = data.Field(tokenize=list, init_token='<bos>', eos_token='<eos>')
+    CHAR = data.NestedField(CHAR_NESTING, init_token='<bos>', eos_token='<eos>', include_lengths=True)
+
+    train, val, test = datasets.UDPOS.splits(
+        fields=((('word', 'char'), (WORD, CHAR)), ('tag', UD_TAG), ('ptbtag', PTB_TAG)),
+        root='.data',
+        train='en-ud-tag.v2.train.txt',
+        validation='en-ud-tag.v2.dev.txt',
+        test='en-ud-tag.v2.test.txt'
+    )
+
+    WORD.build_vocab(train, min_freq=args.word_min_freq)
+    UD_TAG.build_vocab(train)
+    PTB_TAG.build_vocab(train)
+    CHAR.build_vocab(train)
+
+    args.word2idx = WORD.vocab.stoi
+    args.tag2idx = PTB_TAG.vocab.stoi
+    args.char2idx = CHAR.vocab.stoi
+    args.tag_bos = PTB_TAG.init_token
+    args.tag_eos = PTB_TAG.eos_token
+    args.tag_pad = PTB_TAG.pad_token
+
+    train_iter, val_iter, test_iter = data.BucketIterator.splits(
+        (train, val, test),
+        batch_sizes=(args.train_batch_size, args.val_batch_size, args.val_batch_size),
+        device=args.device, repeat=False)
+
+    return train_iter, val_iter, test_iter
+
+
+def get_optimizer_scheduler(model):
+    if args.optimizer in ('sgd', 'nesterov'):
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.sgd_momentum,
+                                    nesterov=args.optimizer == 'nesterov')
+    elif args.optimizer == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    else:
+        raise ValueError('Unknown optimizer specified')
+    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=args.min_lr, factor=args.lr_shrink,
+                                               patience=args.lr_shrink_patience, mode='max')
+    return optimizer, scheduler
+
+
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        return param_group['lr']
+
+
+class EarlyStoppingCriterion(object):
+    def __init__(self, patience):
+        self.count = 0
+        self.best_score = float('-inf')
+        self.patience = patience
+
+    def step(self, score):
+        # reset the counter only when the monitored score actually improves
+        improved = score > self.best_score
+        self.best_score = max(self.best_score, score)
+        self.count = 0 if improved else self.count + 1
+        return self.count <= self.patience
+
+
+def evaluate(model, data_iter, split):
+    model.eval()
+    total, acc = 0, 0
+    data_iter.init_epoch()
+    with torch.no_grad():
+        for batch in data_iter:
+            predictions = model.decode(batch)
+            for i in range(batch.batch_size):
+                total += batch.word[1][i].item()
+                for j in range(batch.word[1][i]):
+                    acc += (predictions[j, i] == batch.tag[j, i]).item()
+    acc = float(acc) / total
+    logger.info('%s acc: %.8f' % (split, acc))
+    return acc
+
+
+def train():
+    train_iter, val_iter, test_iter = get_data_iter()
+    model = SequenceLabelingModel(args, logger).to(args.device)
+    optimizer, scheduler = get_optimizer_scheduler(model)
+    early_stopping_criterion = EarlyStoppingCriterion(patience=args.early_stopping_patience)
+
+    logger.info('Start training')
+
+    for epoch in range(args.max_epochs):
+        cur_lr = get_lr(optimizer)
+        logger.info('Epoch %d, lr %.6f' % (epoch, cur_lr))
+
+        model.train()
+        train_score = []
+        batch_num = len(train_iter)
+        cur_num = 0
+        train_iter.init_epoch()
+        progress = tqdm(train_iter, mininterval=2, leave=False, file=sys.stdout)
+        for i, batch in enumerate(progress):
+            optimizer.zero_grad()
+
+            batch_score = model.forward(batch)
+            train_score.append(batch_score.item())
+            cur_num += batch.batch_size
+            progress.set_description(desc='%d/%d, train loss %.4f' % (i, batch_num, sum(train_score) / cur_num))
+            batch_score.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
+            optimizer.step()
+
+        val_score = evaluate(model, val_iter, 'val')
+        test_score = evaluate(model, test_iter, 'test')
+        if not early_stopping_criterion.step(val_score):
+            break
+        scheduler.step(val_score)
+
+
+if __name__ == '__main__':
+    train()
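The per-token accuracy loop in `evaluate` is straightforward but runs in Python; if it ever becomes a bottleneck, a masked, vectorized version along these lines (a sketch, not part of the example) computes the same per-batch counts:

```python
import torch

def masked_accuracy(predictions, tags, seq_lens):
    """predictions, tags: seq_len x batch_size; seq_lens: batch_size."""
    seq_len = tags.size(0)
    # valid[j, i] is True for positions j < seq_lens[i]
    valid = torch.arange(seq_len, device=tags.device).unsqueeze(1) < seq_lens.unsqueeze(0)
    correct = ((predictions == tags) & valid).sum().item()
    return correct, valid.sum().item()
```

Inside the evaluation loop it could be called per batch as `masked_accuracy(model.decode(batch), batch.tag, batch.word[1])`, accumulating the correct and total counts across batches to reproduce the current metric.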