[v0.8.x][BUGFIX] Update BERT embedding script (#1045)
* update embedding

* display tokens used in the batch, and remove hard coded values

* fix typo

* update embedding

display tokens used in the batch, and remove hard coded values

fix typo

* Update embedding.py

* fix get_model call

* Update embedding.py

* Update embedding.py

* fix download error in test

* Update test_scripts.py
eric-haibin-lin authored and szha committed Dec 19, 2019
1 parent b5ded8f commit ab41699
Showing 4 changed files with 61 additions and 49 deletions.
99 changes: 54 additions & 45 deletions scripts/bert/embedding.py
@@ -1,5 +1,3 @@
-# coding: utf-8
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -29,25 +27,14 @@
 from mxnet.gluon.data import DataLoader
 
 import gluonnlp
-from gluonnlp.data import BERTTokenizer, BERTSentenceTransform
+from gluonnlp.data import BERTTokenizer, BERTSentenceTransform, BERTSPTokenizer
 from gluonnlp.base import get_home_dir
 
 try:
     from data.embedding import BertEmbeddingDataset
 except ImportError:
     from .data.embedding import BertEmbeddingDataset
 
-try:
-    unicode
-except NameError:
-    # Define `unicode` for Python3
-    def unicode(s, *_):
-        return s
-
-
-def to_unicode(s):
-    return unicode(s, 'utf-8')
-
 
 __all__ = ['BertEmbedding']
 
@@ -75,36 +62,50 @@ class BertEmbedding:
         max length of each sequence
     batch_size : int, default 256
         batch size
+    sentencepiece : str, default None
+        Path to the sentencepiece .model file for both tokenization and vocab
     root : str, default '$MXNET_HOME/models' with MXNET_HOME defaults to '~/.mxnet'
         Location for keeping the model parameters.
     """
     def __init__(self, ctx=mx.cpu(), dtype='float32', model='bert_12_768_12',
                  dataset_name='book_corpus_wiki_en_uncased', params_path=None,
-                 max_seq_length=25, batch_size=256,
+                 max_seq_length=25, batch_size=256, sentencepiece=None,
                  root=os.path.join(get_home_dir(), 'models')):
         self.ctx = ctx
         self.dtype = dtype
         self.max_seq_length = max_seq_length
         self.batch_size = batch_size
         self.dataset_name = dataset_name
 
-        # Don't download the pretrained models if we have a parameter path
+        # use sentencepiece vocab and a checkpoint
+        # we need to set dataset_name to None, otherwise it uses the downloaded vocab
+        if params_path and sentencepiece:
+            dataset_name = None
+        else:
+            dataset_name = self.dataset_name
+        if sentencepiece:
+            vocab = gluonnlp.vocab.BERTVocab.from_sentencepiece(sentencepiece)
+        else:
+            vocab = None
         self.bert, self.vocab = gluonnlp.model.get_model(model,
-                                                         dataset_name=self.dataset_name,
+                                                         dataset_name=dataset_name,
                                                          pretrained=params_path is None,
                                                          ctx=self.ctx,
                                                          use_pooler=False,
                                                          use_decoder=False,
                                                          use_classifier=False,
-                                                         root=root)
-        self.bert.cast(self.dtype)
+                                                         root=root, vocab=vocab)
 
+        self.bert.cast(self.dtype)
         if params_path:
             logger.info('Loading params from %s', params_path)
-            self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True)
+            self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True, cast_dtype=True)
 
         lower = 'uncased' in self.dataset_name
-        self.tokenizer = BERTTokenizer(self.vocab, lower=lower)
+        if sentencepiece:
+            self.tokenizer = BERTSPTokenizer(sentencepiece, self.vocab, lower=lower)
+        else:
+            self.tokenizer = BERTTokenizer(self.vocab, lower=lower)
         self.transform = BERTSentenceTransform(tokenizer=self.tokenizer,
                                                max_seq_length=self.max_seq_length,
                                                pair=False)
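For illustration, a minimal usage sketch of the constructor options this hunk introduces; the checkpoint and sentencepiece paths are hypothetical placeholders, and the import assumes the script is run from scripts/bert:

    import mxnet as mx
    from embedding import BertEmbedding

    # Default path: pretrained weights and vocab are downloaded from the model zoo.
    embedder = BertEmbedding(ctx=mx.cpu(), model='bert_12_768_12',
                             dataset_name='book_corpus_wiki_en_uncased')

    # Custom path: a local checkpoint plus a sentencepiece vocab. With both set,
    # dataset_name is switched to None internally, so nothing is downloaded.
    custom = BertEmbedding(ctx=mx.cpu(), model='bert_12_768_12',
                           params_path='my_bert.params',      # hypothetical path
                           sentencepiece='my_vocab.model')    # hypothetical path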
@@ -153,12 +154,9 @@ def oov(self, batches, oov_way='avg'):
         Parameters
         ----------
-        batches : List[(tokens_id,
-                        sequence_outputs,
-                        pooled_output].
-            batch token_ids (max_seq_length, ),
-            sequence_outputs (max_seq_length, dim, ),
-            pooled_output (dim, )
+        batches : List[(tokens_id, sequence_outputs)].
+            batch token_ids shape is (max_seq_length,),
+            sequence_outputs shape is (max_seq_length, dim)
         oov_way : str
             use **avg**, **sum** or **last** to get token embedding for those out of
             vocabulary words
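To make the documented shapes concrete, a toy batches entry under the script defaults (dim 768 for bert_12_768_12) might look like:

    import numpy as np

    max_seq_length, dim = 25, 768
    token_ids = np.zeros(max_seq_length, dtype='int32')                   # (max_seq_length,)
    sequence_outputs = np.zeros((max_seq_length, dim), dtype='float32')   # (max_seq_length, dim)
    batches = [(token_ids, sequence_outputs)]                             # one entry per sentence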
@@ -169,21 +167,29 @@ def oov(self, batches, oov_way='avg'):
         List of tokens, and tokens embedding
         """
         sentences = []
+        padding_idx, cls_idx, sep_idx = None, None, None
+        if self.vocab.padding_token:
+            padding_idx = self.vocab[self.vocab.padding_token]
+        if self.vocab.cls_token:
+            cls_idx = self.vocab[self.vocab.cls_token]
+        if self.vocab.sep_token:
+            sep_idx = self.vocab[self.vocab.sep_token]
         for token_ids, sequence_outputs in batches:
             tokens = []
             tensors = []
             oov_len = 1
             for token_id, sequence_output in zip(token_ids, sequence_outputs):
-                if token_id == 1:
-                    # [PAD] token, sequence is finished.
+                # [PAD] token, sequence is finished.
+                if padding_idx and token_id == padding_idx:
                     break
-                if token_id in (2, 3):
-                    # [CLS], [SEP]
+                # [CLS], [SEP]
+                if cls_idx and token_id == cls_idx:
                     continue
+                if sep_idx and token_id == sep_idx:
+                    continue
                 token = self.vocab.idx_to_token[token_id]
-                if token.startswith('##'):
-                    token = token[2:]
-                    tokens[-1] += token
+                if not self.tokenizer.is_first_subword(token):
+                    tokens.append(token)
                     if oov_way == 'last':
                         tensors[-1] = sequence_output
                     else:
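The avg/sum/last strategies that the loop above feeds into can be summarized by a simplified standalone sketch (not the script's own code):

    import numpy as np

    def pool_subwords(vectors, oov_way='avg'):
        # vectors: per-subword output vectors belonging to one original word
        if oov_way == 'last':
            return vectors[-1]
        pooled = np.sum(vectors, axis=0)
        return pooled / len(vectors) if oov_way == 'avg' else pooled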
@@ -212,19 +218,21 @@
     parser.add_argument('--model', type=str, default='bert_12_768_12',
                         help='pre-trained model')
     parser.add_argument('--dataset_name', type=str, default='book_corpus_wiki_en_uncased',
-                        help='dataset')
+                        help='name of the dataset used for pre-training')
     parser.add_argument('--params_path', type=str, default=None,
                         help='path to a params file to load instead of the pretrained model.')
-    parser.add_argument('--max_seq_length', type=int, default=25,
+    parser.add_argument('--sentencepiece', type=str, default=None,
+                        help='Path to the sentencepiece .model file for tokenization and vocab.')
+    parser.add_argument('--max_seq_length', type=int, default=128,
                         help='max length of each sequence')
     parser.add_argument('--batch_size', type=int, default=256,
                         help='batch size')
     parser.add_argument('--oov_way', type=str, default='avg',
-                        help='how to handle oov\n'
-                             'avg: average all oov embeddings to represent the original token\n'
-                             'sum: sum all oov embeddings to represent the original token\n'
-                             'last: use last oov embeddings to represent the original token\n')
-    parser.add_argument('--sentences', type=to_unicode, nargs='+', default=None,
+                        help='how to handle subword embeddings\n'
+                             'avg: average all subword embeddings to represent the original token\n'
+                             'sum: sum all subword embeddings to represent the original token\n'
+                             'last: use last subword embeddings to represent the original token\n')
+    parser.add_argument('--sentences', type=str, nargs='+', default=None,
                         help='sentence for encoding')
     parser.add_argument('--file', type=str, default=None,
                         help='file for encoding')
@@ -240,7 +248,8 @@
     else:
         context = mx.cpu()
     bert_embedding = BertEmbedding(ctx=context, model=args.model, dataset_name=args.dataset_name,
-                                   max_seq_length=args.max_seq_length, batch_size=args.batch_size)
+                                   max_seq_length=args.max_seq_length, batch_size=args.batch_size,
+                                   params_path=args.params_path, sentencepiece=args.sentencepiece)
     result = []
     sents = []
     if args.sentences:
@@ -255,7 +264,7 @@
         logger.error('Please specify --sentence or --file')
 
     if result:
-        for sent, embeddings in zip(sents, result):
-            print('Text: {}'.format(sent))
-            _, tokens_embedding = embeddings
+        for _, embeddings in zip(sents, result):
+            sent, tokens_embedding = embeddings
+            print('Text: {}'.format(' '.join(sent)))
             print('Tokens embedding: {}'.format(tokens_embedding))
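As a standalone illustration of the sentencepiece pieces wired in above, a BERTVocab and its matching tokenizer can be built directly from a sentencepiece model file; the path below is a hypothetical placeholder:

    import gluonnlp

    sp_model = 'my_vocab.model'  # hypothetical sentencepiece model file
    vocab = gluonnlp.vocab.BERTVocab.from_sentencepiece(sp_model)
    tokenizer = gluonnlp.data.BERTSPTokenizer(sp_model, vocab, lower=True)
    print(tokenizer('GluonNLP is a toolkit for natural language processing.'))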
2 changes: 1 addition & 1 deletion scripts/bert/finetune_classifier.py
@@ -219,7 +219,7 @@
     except ImportError:
         # amp is not available
         logging.info('Mixed precision training with float16 requires MXNet >= '
-                     '1.5.0b20190627. Please consider upgrading your MXNet version.')
+                     '1.5.1. Please consider upgrading your MXNet version.')
         exit()
 
 # model and loss
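For context, a hedged sketch of the optional-import guard this message lives in, assuming MXNet's contrib AMP module is the optional dependency:

    import logging

    try:
        from mxnet.contrib import amp  # ships with newer MXNet releases
        amp_available = True
    except ImportError:
        amp_available = False
        logging.info('Mixed precision training with float16 requires MXNet >= '
                     '1.5.1. Please consider upgrading your MXNet version.')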
4 changes: 2 additions & 2 deletions scripts/bert/index.rst
@@ -278,8 +278,8 @@ The goal of this BERT Embedding is to obtain the token embedding from BERT's pre
 
 .. code-block:: shell
 
-    python bert/embedding.py --sentences "GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research."
-    Text: GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research.
+    python embedding.py --sentences "GluonNLP is a toolkit that enables easy text preprocessing, datasets loading and neural models building to help you speed up your Natural Language Processing (NLP) research."
+    Text: g ##lu ##on ##nl ##p is a tool ##kit that enables easy text prep ##ro ##ces ##sing , data ##set ##s loading and neural models building to help you speed up your natural language processing ( nl ##p ) research .
     Tokens embedding: [array([-0.11881411, -0.59530115, 0.627092 , ..., 0.00648153,
            -0.03886228, 0.03406909], dtype=float32), array([-0.7995638 , -0.6540758 , -0.00521846, ..., -0.42272145,
            -0.5787281 , 0.7021201 ], dtype=float32), array([-0.7406778 , -0.80276626, 0.3931962 , ..., -0.49068323,
5 changes: 4 additions & 1 deletion scripts/tests/test_scripts.py
@@ -25,6 +25,7 @@
 
 import pytest
 import mxnet as mx
+import gluonnlp as nlp
 
 @pytest.mark.serial
 @pytest.mark.remote_required
@@ -200,8 +201,10 @@ def test_bert_embedding(use_pretrained):
     if use_pretrained:
         args.extend(['--dtype', 'float32'])
     else:
+        _, _ = nlp.model.get_model('bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased',
+                                   pretrained=True, root='test_bert_embedding')
         args.extend(['--params_path',
-                     '~/.mxnet/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.params'])
+                     'test_bert_embedding/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.params'])
     process = subprocess.check_call([sys.executable, './scripts/bert/embedding.py'] + args)
     time.sleep(5)

