From ab41699a3f072b903ebe062517ddcca2c36fb0f8 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Wed, 18 Dec 2019 21:13:32 -0800 Subject: [PATCH] [v0.8.x][BUGFIX] Update BERT embedding script (#1045) * update embedding * display tokens used in the batch, and remove hard coded values * fix typo * update embedding display tokens used in the batch, and remove hard coded values fix typo * Update embedding.py * fix get_model call * Update embedding.py * Update embedding.py * fix download error in test * Update test_scripts.py --- scripts/bert/embedding.py | 99 ++++++++++++++++------------- scripts/bert/finetune_classifier.py | 2 +- scripts/bert/index.rst | 4 +- scripts/tests/test_scripts.py | 5 +- 4 files changed, 61 insertions(+), 49 deletions(-) diff --git a/scripts/bert/embedding.py b/scripts/bert/embedding.py index f680583a3a..0c11d739fc 100644 --- a/scripts/bert/embedding.py +++ b/scripts/bert/embedding.py @@ -1,5 +1,3 @@ -# coding: utf-8 - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -29,7 +27,7 @@ from mxnet.gluon.data import DataLoader import gluonnlp -from gluonnlp.data import BERTTokenizer, BERTSentenceTransform +from gluonnlp.data import BERTTokenizer, BERTSentenceTransform, BERTSPTokenizer from gluonnlp.base import get_home_dir try: @@ -37,17 +35,6 @@ except ImportError: from .data.embedding import BertEmbeddingDataset -try: - unicode -except NameError: - # Define `unicode` for Python3 - def unicode(s, *_): - return s - - -def to_unicode(s): - return unicode(s, 'utf-8') - __all__ = ['BertEmbedding'] @@ -75,12 +62,14 @@ class BertEmbedding: max length of each sequence batch_size : int, default 256 batch size + sentencepiece : str, default None + Path to the sentencepiece .model file for both tokenization and vocab root : str, default '$MXNET_HOME/models' with MXNET_HOME defaults to '~/.mxnet' Location for keeping the model parameters. 
""" def __init__(self, ctx=mx.cpu(), dtype='float32', model='bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased', params_path=None, - max_seq_length=25, batch_size=256, + max_seq_length=25, batch_size=256, sentencepiece=None, root=os.path.join(get_home_dir(), 'models')): self.ctx = ctx self.dtype = dtype @@ -88,23 +77,35 @@ def __init__(self, ctx=mx.cpu(), dtype='float32', model='bert_12_768_12', self.batch_size = batch_size self.dataset_name = dataset_name - # Don't download the pretrained models if we have a parameter path + # use sentencepiece vocab and a checkpoint + # we need to set dataset_name to None, otherwise it uses the downloaded vocab + if params_path and sentencepiece: + dataset_name = None + else: + dataset_name = self.dataset_name + if sentencepiece: + vocab = gluonnlp.vocab.BERTVocab.from_sentencepiece(sentencepiece) + else: + vocab = None self.bert, self.vocab = gluonnlp.model.get_model(model, - dataset_name=self.dataset_name, + dataset_name=dataset_name, pretrained=params_path is None, ctx=self.ctx, use_pooler=False, use_decoder=False, use_classifier=False, - root=root) - self.bert.cast(self.dtype) + root=root, vocab=vocab) + self.bert.cast(self.dtype) if params_path: logger.info('Loading params from %s', params_path) - self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True) + self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True, cast_dtype=True) lower = 'uncased' in self.dataset_name - self.tokenizer = BERTTokenizer(self.vocab, lower=lower) + if sentencepiece: + self.tokenizer = BERTSPTokenizer(sentencepiece, self.vocab, lower=lower) + else: + self.tokenizer = BERTTokenizer(self.vocab, lower=lower) self.transform = BERTSentenceTransform(tokenizer=self.tokenizer, max_seq_length=self.max_seq_length, pair=False) @@ -153,12 +154,9 @@ def oov(self, batches, oov_way='avg'): Parameters ---------- - batches : List[(tokens_id, - sequence_outputs, - pooled_output]. - batch token_ids (max_seq_length, ), - sequence_outputs (max_seq_length, dim, ), - pooled_output (dim, ) + batches : List[(tokens_id, sequence_outputs)]. + batch token_ids shape is (max_seq_length,), + sequence_outputs shape is (max_seq_length, dim) oov_way : str use **avg**, **sum** or **last** to get token embedding for those out of vocabulary words @@ -169,21 +167,29 @@ def oov(self, batches, oov_way='avg'): List of tokens, and tokens embedding """ sentences = [] + padding_idx, cls_idx, sep_idx = None, None, None + if self.vocab.padding_token: + padding_idx = self.vocab[self.vocab.padding_token] + if self.vocab.cls_token: + cls_idx = self.vocab[self.vocab.cls_token] + if self.vocab.sep_token: + sep_idx = self.vocab[self.vocab.sep_token] for token_ids, sequence_outputs in batches: tokens = [] tensors = [] oov_len = 1 for token_id, sequence_output in zip(token_ids, sequence_outputs): - if token_id == 1: - # [PAD] token, sequence is finished. + # [PAD] token, sequence is finished. 
+                if padding_idx and token_id == padding_idx:
                     break
-                if token_id in (2, 3):
-                    # [CLS], [SEP]
+                # [CLS], [SEP]
+                if cls_idx and token_id == cls_idx:
+                    continue
+                if sep_idx and token_id == sep_idx:
                     continue
                 token = self.vocab.idx_to_token[token_id]
-                if token.startswith('##'):
-                    token = token[2:]
-                    tokens[-1] += token
+                if not self.tokenizer.is_first_subword(token):
+                    tokens.append(token)
                     if oov_way == 'last':
                         tensors[-1] = sequence_output
                     else:
@@ -212,19 +218,21 @@ def oov(self, batches, oov_way='avg'):
     parser.add_argument('--model', type=str, default='bert_12_768_12',
                         help='pre-trained model')
     parser.add_argument('--dataset_name', type=str, default='book_corpus_wiki_en_uncased',
-                        help='dataset')
+                        help='name of the dataset used for pre-training')
     parser.add_argument('--params_path', type=str, default=None,
                         help='path to a params file to load instead of the pretrained model.')
-    parser.add_argument('--max_seq_length', type=int, default=25,
+    parser.add_argument('--sentencepiece', type=str, default=None,
+                        help='Path to the sentencepiece .model file for tokenization and vocab.')
+    parser.add_argument('--max_seq_length', type=int, default=128,
                         help='max length of each sequence')
     parser.add_argument('--batch_size', type=int, default=256,
                         help='batch size')
     parser.add_argument('--oov_way', type=str, default='avg',
-                        help='how to handle oov\n'
-                             'avg: average all oov embeddings to represent the original token\n'
-                             'sum: sum all oov embeddings to represent the original token\n'
-                             'last: use last oov embeddings to represent the original token\n')
-    parser.add_argument('--sentences', type=to_unicode, nargs='+', default=None,
+                        help='how to handle subword embeddings\n'
+                             'avg: average all subword embeddings to represent the original token\n'
+                             'sum: sum all subword embeddings to represent the original token\n'
+                             'last: use last subword embeddings to represent the original token\n')
+    parser.add_argument('--sentences', type=str, nargs='+', default=None,
                         help='sentence for encoding')
     parser.add_argument('--file', type=str, default=None,
                         help='file for encoding')
@@ -240,7 +248,8 @@ def oov(self, batches, oov_way='avg'):
     else:
         context = mx.cpu()
     bert_embedding = BertEmbedding(ctx=context, model=args.model, dataset_name=args.dataset_name,
-                                   max_seq_length=args.max_seq_length, batch_size=args.batch_size)
+                                   max_seq_length=args.max_seq_length, batch_size=args.batch_size,
+                                   params_path=args.params_path, sentencepiece=args.sentencepiece)
     result = []
     sents = []
     if args.sentences:
@@ -255,7 +264,7 @@ def oov(self, batches, oov_way='avg'):
         logger.error('Please specify --sentence or --file')
 
     if result:
-        for sent, embeddings in zip(sents, result):
-            print('Text: {}'.format(sent))
-            _, tokens_embedding = embeddings
+        for _, embeddings in zip(sents, result):
+            sent, tokens_embedding = embeddings
+            print('Text: {}'.format(' '.join(sent)))
             print('Tokens embedding: {}'.format(tokens_embedding))
diff --git a/scripts/bert/finetune_classifier.py b/scripts/bert/finetune_classifier.py
index 60d7fd05bd..b6d4e35522 100644
--- a/scripts/bert/finetune_classifier.py
+++ b/scripts/bert/finetune_classifier.py
@@ -219,7 +219,7 @@
 except ImportError:
     # amp is not available
     logging.info('Mixed precision training with float16 requires MXNet >= '
-                 '1.5.0b20190627. Please consider upgrading your MXNet version.')
+                 '1.5.1. Please consider upgrading your MXNet version.')
     exit()
 
 # model and loss
diff --git a/scripts/bert/index.rst b/scripts/bert/index.rst
index 4f8a47bee5..67e8d799f7 100644
--- a/scripts/bert/index.rst
+++ b/scripts/bert/index.rst
@@ -278,8 +278,8 @@ The goal of this BERT Embedding is to obtain the token embedding from BERT's pre
 
 .. code-block:: shell
 
-    python bert/embedding.py --sentences "GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research."
-    Text: GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research.
+    python embedding.py --sentences "GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research."
+    Text: g ##lu ##on ##nl ##p is a tool ##kit that enables easy text prep ##ro ##ces ##sing , data ##set ##s loading and neural models building to help you speed up your natural language processing ( nl ##p ) research .
     Tokens embedding: [array([-0.11881411, -0.59530115,  0.627092  , ...,  0.00648153,
            -0.03886228,  0.03406909], dtype=float32), array([-0.7995638 , -0.6540758 , -0.00521846, ..., -0.42272145,
            -0.5787281 ,  0.7021201 ], dtype=float32), array([-0.7406778 , -0.80276626,  0.3931962 , ..., -0.49068323,
diff --git a/scripts/tests/test_scripts.py b/scripts/tests/test_scripts.py
index e82a14700c..e43b09fd48 100644
--- a/scripts/tests/test_scripts.py
+++ b/scripts/tests/test_scripts.py
@@ -25,6 +25,7 @@
 
 import pytest
 import mxnet as mx
+import gluonnlp as nlp
 
 @pytest.mark.serial
 @pytest.mark.remote_required
@@ -200,8 +201,10 @@ def test_bert_embedding(use_pretrained):
     if use_pretrained:
         args.extend(['--dtype', 'float32'])
     else:
+        _, _ = nlp.model.get_model('bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased',
+                                   pretrained=True, root='test_bert_embedding')
        args.extend(['--params_path',
-                     '~/.mxnet/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.params'])
+                     'test_bert_embedding/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.params'])
     process = subprocess.check_call([sys.executable, './scripts/bert/embedding.py'] + args)
     time.sleep(5)
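
For reference, a minimal usage sketch of the BertEmbedding class as updated by this patch. The constructor arguments and the (tokens, token_embeddings) result layout come from the hunks above; the import path assumes the interpreter is started from scripts/bert/, the callable entry point is assumed to match how the script's __main__ block builds `result`, and 'my_bert.params' / 'my_vocab.model' are hypothetical placeholder paths for a custom checkpoint and its sentencepiece model.

    import mxnet as mx

    from embedding import BertEmbedding  # scripts/bert/embedding.py

    # Default setup: vocab and weights for bert_12_768_12 /
    # book_corpus_wiki_en_uncased are downloaded automatically.
    embedder = BertEmbedding(ctx=mx.cpu(), max_seq_length=128, batch_size=256)

    # Custom checkpoint (hypothetical paths): passing both params_path and
    # sentencepiece makes the class build its vocab from the .model file and
    # tokenize with BERTSPTokenizer instead of BERTTokenizer.
    # embedder = BertEmbedding(ctx=mx.cpu(),
    #                          params_path='my_bert.params',
    #                          sentencepiece='my_vocab.model')

    sentences = ['GluonNLP is a toolkit that enables easy text preprocessing.']
    # Assumed entry point: one (tokens, token_embeddings) pair per sentence,
    # which is how the script's print loop above unpacks `result`.
    for tokens, token_embeddings in embedder(sentences):
        print(tokens)
        print(len(token_embeddings), token_embeddings[0].shape)

Subword merging ('avg', 'sum', or 'last') is handled by the oov() helper shown in the patch and is exposed on the command line through --oov_way.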