From 3922d0659a8605d54b639dc4d8a335c92c7244fe Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Thu, 8 Aug 2019 09:16:11 -0800 Subject: [PATCH] [model] Roberta converted weights (#870) * +roberta * fix vocab * remove self attention * add model store * add test * add doc * fix doc * fix tset * fix lint * separate class for roberta * fix lint * fix doc --- .../conversion_tools/convert_fairseq_model.py | 216 ++++++++++++ scripts/bert/conversion_tools/utils.py | 74 +++++ scripts/bert/index.rst | 25 ++ scripts/bert/pretraining_utils.py | 19 +- scripts/bert/run_pretraining.py | 3 +- scripts/bert/run_pretraining_hvd.py | 3 +- src/gluonnlp/data/utils.py | 2 + src/gluonnlp/model/__init__.py | 2 + src/gluonnlp/model/bert.py | 307 ++++++++++++++++-- tests/unittest/test_models.py | 34 +- 10 files changed, 659 insertions(+), 26 deletions(-) create mode 100644 scripts/bert/conversion_tools/convert_fairseq_model.py create mode 100644 scripts/bert/conversion_tools/utils.py diff --git a/scripts/bert/conversion_tools/convert_fairseq_model.py b/scripts/bert/conversion_tools/convert_fairseq_model.py new file mode 100644 index 0000000000..4a439904ed --- /dev/null +++ b/scripts/bert/conversion_tools/convert_fairseq_model.py @@ -0,0 +1,216 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name,logging-format-interpolation +""" Script for converting Fairseq Roberta Model to Gluon. """ +import argparse +import logging +import os +import sys +import io +import numpy as np + +import torch +from fairseq.models.roberta import RobertaModel + +import mxnet as mx +import gluonnlp as nlp +from gluonnlp.model import BERTEncoder, BERTModel +from gluonnlp.model.bert import bert_hparams +from gluonnlp.data.utils import _load_pretrained_vocab + +from utils import get_hash, load_text_vocab, tf_vocab_to_gluon_vocab + +parser = argparse.ArgumentParser(description='Conversion script for Fairseq RoBERTa model', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--ckpt_dir', type=str, help='Full path to the roberta folder', + default='/home/ubuntu/roberta/roberta.base') +parser.add_argument('--model', type=str, help='Model type. 
', + choices=['roberta_12_768_12', 'roberta_24_1024_16'], + default='roberta_12_768_12') +parser.add_argument('--verbose', action='store_true', help='Verbose logging') + +args = parser.parse_args() + +ckpt_dir = os.path.expanduser(args.ckpt_dir) + +ckpt = torch.load(os.path.join(ckpt_dir, 'model.pt')) +pytorch_params = ckpt['model'] + +if args.verbose: + print(ckpt['args']) + for k, v in pytorch_params.items(): + print(k, v.shape) + +# Load the model in fairseq +roberta = RobertaModel.from_pretrained(ckpt_dir) +roberta.eval() + +def fairseq_vocab_to_gluon_vocab(torch_vocab): + index_to_words = [None] * len(torch_vocab) + + bos_idx = torch_vocab.bos() + pad_idx = torch_vocab.pad() + eos_idx = torch_vocab.eos() + unk_idx = torch_vocab.unk() + + index_to_words[bos_idx] = torch_vocab.symbols[bos_idx] + index_to_words[pad_idx] = torch_vocab.symbols[pad_idx] + index_to_words[eos_idx] = torch_vocab.symbols[eos_idx] + index_to_words[unk_idx] = torch_vocab.symbols[unk_idx] + + specials = [bos_idx, pad_idx, eos_idx, unk_idx] + + openai_to_roberta = {} + openai_vocab = _load_pretrained_vocab('openai_webtext', '.') + + with io.open(os.path.join(ckpt_dir, 'dict.txt'), encoding='utf-8') as f: + for i, line in enumerate(f): + token, count = line.split(' ') + try: + fake_token = int(token) + openai_to_roberta[token] = i + len(specials) + except ValueError: + index_to_words[i + len(specials)] = token + + for idx, token in enumerate(openai_vocab.idx_to_token): + if str(idx) in openai_to_roberta: + index_to_words[openai_to_roberta[str(idx)]] = token + else: + assert token == u'', token + + mask_idx = torch_vocab.index(u'') + index_to_words[mask_idx] = torch_vocab.string([mask_idx]) + assert None not in index_to_words + word2idx = {} + for idx, token in enumerate(index_to_words): + word2idx[token] = idx + + vocab = nlp.vocab.Vocab(word2idx, token_to_idx=word2idx, + unknown_token=index_to_words[unk_idx], + padding_token=index_to_words[pad_idx], + bos_token=index_to_words[bos_idx], + eos_token=index_to_words[eos_idx], + mask_token=u'') + return vocab + +vocab = fairseq_vocab_to_gluon_vocab(roberta.task.dictionary) + +predefined_args = bert_hparams[args.model] + +# BERT encoder +encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'], + num_layers=predefined_args['num_layers'], units=predefined_args['units'], + hidden_size=predefined_args['hidden_size'], + max_length=predefined_args['max_length'], + num_heads=predefined_args['num_heads'], scaled=predefined_args['scaled'], + dropout=predefined_args['dropout'], + use_residual=predefined_args['use_residual'], + layer_norm_eps=predefined_args['layer_norm_eps']) + +# BERT model +bert = BERTModel(encoder, len(vocab), + units=predefined_args['units'], embed_size=predefined_args['embed_size'], + embed_dropout=predefined_args['embed_dropout'], + word_embed=predefined_args['word_embed'], use_pooler=False, + use_token_type_embed=False, use_classifier=False) + +bert.initialize(init=mx.init.Normal(0.02)) + +ones = mx.nd.ones((2, 8)) +out = bert(ones, None, mx.nd.array([5, 6]), mx.nd.array([[1], [2]])) +params = bert._collect_params_with_prefix() + + + +mapping = { + 'decoder.2' : 'decoder.lm_head.layer_norm', + 'decoder.0' : 'decoder.lm_head.dense', + 'decoder.3' : 'decoder.lm_head', + 'encoder.layer_norm' : 'decoder.sentence_encoder.emb_layer_norm', + 'encoder.position_weight' : 'decoder.sentence_encoder.embed_positions.weight', + 'encoder.transformer_cells': 'decoder.sentence_encoder.layers', + 'attention_cell.proj_key.' 
: 'self_attn.in_proj_', + 'attention_cell.proj_value.' : 'self_attn.in_proj_', + 'attention_cell.proj_query.' : 'self_attn.in_proj_', + 'ffn.ffn_1' : 'fc1', + 'ffn.ffn_2' : 'fc2', + 'layer_norm.gamma' : 'layer_norm.weight', + 'layer_norm.beta' : 'layer_norm.bias', + 'ffn.layer_norm' : 'final_layer_norm', + 'word_embed.0.weight' : 'decoder.sentence_encoder.embed_tokens.weight', +} + +for i in range(24): + mapping['{}.layer_norm'.format(i)] = '{}.self_attn_layer_norm'.format(i) + mapping['{}.proj'.format(i)] = '{}.self_attn.out_proj'.format(i) + +# set parameter data +loaded_params = {} +visited_pytorch_params = {} +for name in params: + pytorch_name = name + for source, dest in mapping.items(): + pytorch_name = pytorch_name.replace(source, dest) + + assert pytorch_name in pytorch_params.keys(), 'Key ' + pytorch_name + ' for ' + name + ' not found.' + torch_arr = pytorch_params[pytorch_name].cpu() + # fairseq positional embedding starts with index 2 + if pytorch_name == 'decoder.sentence_encoder.embed_positions.weight': + torch_arr = torch_arr[2:] + + arr = mx.nd.array(torch_arr) + if 'attention_cell.proj' in name: + unfused = ['query', 'key', 'value'] + arrs = arr.split(num_outputs=3, axis=0) + for i, p in enumerate(unfused): + if p in name: + arr = arrs[i] + else: + assert arr.shape == params[name].shape, (arr.shape, params[name].shape, name, pytorch_name) + params[name].set_data(arr) + loaded_params[name] = True + visited_pytorch_params[pytorch_name] = True + +assert len(params) == len(loaded_params) +assert len(visited_pytorch_params) == len(pytorch_params), "Gluon model does not match PyTorch model. " \ + "Please fix the BERTModel hyperparameters\n" + str(len(visited_pytorch_params)) + ' v.s. ' + str(len(pytorch_params)) + + +texts = 'Hello world. abc, def and δΈ­ζ–‡!' +torch_tokens = roberta.encode(texts) + +torch_features = roberta.extract_features(torch_tokens) +pytorch_out = torch_features.detach().numpy() + +mx_tokenizer = nlp.data.GPT2BPETokenizer() +mx_tokens = [vocab.bos_token] + mx_tokenizer(texts) + [vocab.eos_token] +mx_data = vocab[mx_tokens] +print(mx_tokens) +print(vocab[mx_tokens]) +print(torch_tokens) +assert mx_data == torch_tokens.tolist() + +mx_out = bert(mx.nd.array([mx_data])) +print('stdev = ', np.std(mx_out.asnumpy() - pytorch_out)) +mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=1e-3, rtol=1e-3) +mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=5e-6, rtol=5e-6) + +bert.save_parameters(os.path.join(ckpt_dir, args.model + '.params')) +with io.open(os.path.join(ckpt_dir, args.model + '.vocab'), 'w', encoding='utf-8') as f: + f.write(vocab.to_json()) diff --git a/scripts/bert/conversion_tools/utils.py b/scripts/bert/conversion_tools/utils.py new file mode 100644 index 0000000000..fcba8159f2 --- /dev/null +++ b/scripts/bert/conversion_tools/utils.py @@ -0,0 +1,74 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility functions for BERT.""" + +import logging +import collections +import hashlib +import io + +import mxnet as mx +import gluonnlp as nlp + +__all__ = ['tf_vocab_to_gluon_vocab', 'load_text_vocab'] + + +def tf_vocab_to_gluon_vocab(tf_vocab): + special_tokens = ['[UNK]', '[PAD]', '[SEP]', '[MASK]', '[CLS]'] + assert all(t in tf_vocab for t in special_tokens) + counter = nlp.data.count_tokens(tf_vocab.keys()) + vocab = nlp.vocab.BERTVocab(counter, token_to_idx=tf_vocab) + return vocab + + +def get_hash(filename): + sha1 = hashlib.sha1() + with open(filename, 'rb') as f: + while True: + data = f.read(1048576) + if not data: + break + sha1.update(data) + return sha1.hexdigest(), str(sha1.hexdigest())[:8] + + +def read_tf_checkpoint(path): + """read tensorflow checkpoint""" + from tensorflow.python import pywrap_tensorflow + tensors = {} + reader = pywrap_tensorflow.NewCheckpointReader(path) + var_to_shape_map = reader.get_variable_to_shape_map() + for key in sorted(var_to_shape_map): + tensor = reader.get_tensor(key) + tensors[key] = tensor + return tensors + +def load_text_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with io.open(vocab_file, 'r') as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab diff --git a/scripts/bert/index.rst b/scripts/bert/index.rst index 708f872fc9..adba86aab9 100644 --- a/scripts/bert/index.rst +++ b/scripts/bert/index.rst @@ -47,6 +47,31 @@ The following pre-trained BERT models are available from the **gluonnlp.model.ge where **bert_12_768_12** refers to the BERT BASE model, and **bert_24_1024_16** refers to the BERT LARGE model. +.. code-block:: python + + import gluonnlp as nlp; import mxnet as mx; + model, vocab = nlp.model.get_model('bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased', use_classifier=False); + tokenizer = nlp.data.BERTTokenizer(vocab, lower=True); + transform = nlp.data.BERTSentenceTransform(tokenizer, max_seq_length=512, pair=False, pad=False); + sample = transform(['Hello world!']); + words, valid_len, segments = mx.nd.array([sample[0]]), mx.nd.array([sample[1]]), mx.nd.array([sample[2]]); + seq_encoding, cls_encoding = model(words, segments, valid_len); + +Additionally, GluonNLP supports the "`RoBERTa `_" model: + ++-----------------------------------------+-------------------+--------------------+ +| | roberta_12_768_12 | roberta_24_1024_16 | ++=========================================+===================+====================+ +| openwebtext_ccnews_stories_books_cased | βœ“ | βœ“ | ++-----------------------------------------+-------------------+--------------------+ + +.. code-block:: python + + import gluonnlp as nlp; import mxnet as mx; + model, vocab = nlp.model.get_model('roberta_12_768_12', dataset_name='openwebtext_ccnews_stories_books_cased'); + tokenizer = nlp.data.GPT2BPETokenizer(); + text = [vocab.bos_token] + tokenizer('Hello world!') + [vocab.eos_token]; + seq_encoding = model(mx.nd.array([vocab[text]])) .. 
hint:: diff --git a/scripts/bert/pretraining_utils.py b/scripts/bert/pretraining_utils.py index 48073bf815..f5c7a418fe 100644 --- a/scripts/bert/pretraining_utils.py +++ b/scripts/bert/pretraining_utils.py @@ -39,7 +39,7 @@ __all__ = ['get_model_loss', 'get_pretrain_data_npz', 'get_dummy_dataloader', 'save_parameters', 'save_states', 'evaluate', 'forward', 'split_and_load', - 'get_argparser', 'get_pretrain_data_text', 'generate_dev_set'] + 'get_argparser', 'get_pretrain_data_text', 'generate_dev_set', 'profile'] def get_model_loss(ctx, model, pretrained, dataset_name, vocab, dtype, ckpt_dir=None, start_step=None): @@ -505,3 +505,20 @@ def generate_dev_set(tokenizer, vocab, cache_file, args): 1, args.num_data_workers, worker_pool, cache_file)) logging.info('Done generating validation set on rank 0.') + +def profile(curr_step, start_step, end_step, profile_name='profile.json', + early_exit=True): + """profile the program between [start_step, end_step).""" + if curr_step == start_step: + mx.nd.waitall() + mx.profiler.set_config(profile_memory=False, profile_symbolic=True, + profile_imperative=True, filename=profile_name, + aggregate_stats=True) + mx.profiler.set_state('run') + elif curr_step == end_step: + mx.nd.waitall() + mx.profiler.set_state('stop') + logging.info(mx.profiler.dumps()) + mx.profiler.dump() + if early_exit: + exit() diff --git a/scripts/bert/run_pretraining.py b/scripts/bert/run_pretraining.py index 33d8d3b933..f2ac2726f4 100644 --- a/scripts/bert/run_pretraining.py +++ b/scripts/bert/run_pretraining.py @@ -37,11 +37,10 @@ import mxnet as mx import gluonnlp as nlp -from utils import profile from fp16_utils import FP16Trainer from pretraining_utils import get_model_loss, get_pretrain_data_npz, get_dummy_dataloader from pretraining_utils import log, evaluate, forward, split_and_load, get_argparser -from pretraining_utils import save_parameters, save_states +from pretraining_utils import save_parameters, save_states, profile # arg parser parser = get_argparser() diff --git a/scripts/bert/run_pretraining_hvd.py b/scripts/bert/run_pretraining_hvd.py index 78bf55dd51..2b1f974619 100644 --- a/scripts/bert/run_pretraining_hvd.py +++ b/scripts/bert/run_pretraining_hvd.py @@ -40,11 +40,10 @@ import mxnet as mx import gluonnlp as nlp -from utils import profile from fp16_utils import FP16Trainer from pretraining_utils import get_model_loss, get_pretrain_data_npz, get_dummy_dataloader from pretraining_utils import split_and_load, log, evaluate, forward, get_argparser -from pretraining_utils import save_parameters, save_states +from pretraining_utils import save_parameters, save_states, profile from pretraining_utils import get_pretrain_data_text, generate_dev_set # parser diff --git a/src/gluonnlp/data/utils.py b/src/gluonnlp/data/utils.py index 925c2c3991..888ebc4c6d 100644 --- a/src/gluonnlp/data/utils.py +++ b/src/gluonnlp/data/utils.py @@ -225,6 +225,8 @@ def _slice_pad_length(num_items, length, overlap=0): 'book_corpus_wiki_en_uncased': 'a66073971aa0b1a262453fe51342e57166a8abcf', 'openwebtext_book_corpus_wiki_en_uncased': 'a66073971aa0b1a262453fe51342e57166a8abcf', + 'openwebtext_ccnews_stories_books_cased': + '2b804f8f90f9f93c07994b703ce508725061cf43', 'wiki_multilingual_cased': '0247cb442074237c38c62021f36b7a4dbd2e55f7', 'wiki_cn_cased': 'ddebd8f3867bca5a61023f73326fb125cf12b4f5', 'wiki_multilingual_uncased': '2b2514cc539047b9179e9d98a4e68c36db05c97a', diff --git a/src/gluonnlp/model/__init__.py b/src/gluonnlp/model/__init__.py index ee3c3044cb..b77c4a9288 100644 --- 
a/src/gluonnlp/model/__init__.py +++ b/src/gluonnlp/model/__init__.py @@ -142,6 +142,8 @@ def get_model(name, **kwargs): 'transformer_en_de_512': transformer_en_de_512, 'bert_12_768_12' : bert_12_768_12, 'bert_24_1024_16' : bert_24_1024_16, + 'roberta_12_768_12' : roberta_12_768_12, + 'roberta_24_1024_16' : roberta_24_1024_16, 'ernie_12_768_12' : ernie_12_768_12} name = name.lower() if name not in models: diff --git a/src/gluonnlp/model/bert.py b/src/gluonnlp/model/bert.py index 1582b431e9..30a5a08d59 100644 --- a/src/gluonnlp/model/bert.py +++ b/src/gluonnlp/model/bert.py @@ -17,10 +17,11 @@ # specific language governing permissions and limitations # under the License. """BERT models.""" +# pylint: disable=too-many-lines -__all__ = ['BERTModel', 'BERTEncoder', 'BERTEncoderCell', 'BERTPositionwiseFFN', +__all__ = ['BERTModel', 'RoBERTaModel', 'BERTEncoder', 'BERTEncoderCell', 'BERTPositionwiseFFN', 'BERTLayerNorm', 'bert_12_768_12', 'bert_24_1024_16', - 'ernie_12_768_12'] + 'ernie_12_768_12', 'roberta_12_768_12', 'roberta_24_1024_16'] import os from mxnet.gluon import Block @@ -41,6 +42,7 @@ class BERTLayerNorm(nn.LayerNorm): """BERT style Layer Normalization. + Epsilon is added inside the square root and set to 1e-12 by default. Inputs: @@ -51,6 +53,7 @@ class BERTLayerNorm(nn.LayerNorm): def __init__(self, epsilon=1e-12, in_channels=0, prefix=None, params=None): super(BERTLayerNorm, self).__init__(epsilon=epsilon, in_channels=in_channels, prefix=prefix, params=params) + def hybrid_forward(self, F, data, gamma, beta): """forward computation.""" return F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon) @@ -278,7 +281,7 @@ class BERTModel(Block): vocab_size : int or None, default None The size of the vocabulary. token_type_vocab_size : int or None, default None - The vocabulary size of token types. + The vocabulary size of token types (number of segments). units : int or None, default None Number of units for the final pooler layer. embed_size : int or None, default None @@ -294,8 +297,8 @@ class BERTModel(Block): The word embedding. If set to None, word_embed will be constructed using embed_size and embed_dropout. token_type_embed : Block or None, default None - The token type embedding. If set to None and the token_type_embed will be constructed using - embed_size and embed_dropout. + The token type embedding (segment embedding). If set to None and the token_type_embed will + be constructed using embed_size and embed_dropout. use_pooler : bool, default True Whether to include the pooler which converts the encoded sequence tensor of shape (batch_size, seq_length, units) to a tensor of shape (batch_size, units) @@ -304,6 +307,8 @@ class BERTModel(Block): Whether to include the decoder for masked language model prediction. use_classifier : bool, default True Whether to include the classifier for next sentence classification. + use_token_type_embed : bool, default True + Whether to include token type embedding (segment embedding). prefix : str or None See document of `mx.gluon.Block`. params : ParameterDict or None @@ -311,7 +316,7 @@ class BERTModel(Block): Inputs: - **inputs**: input sequence tensor, shape (batch_size, seq_length) - - **token_types**: input token type tensor, shape (batch_size, seq_length). + - **token_types**: optional input token type tensor, shape (batch_size, seq_length). If the inputs contain two sequences, then the token type of the first sequence differs from that of the second one. 
- **valid_length**: optional tensor of input sequence valid lengths, shape (batch_size,) @@ -338,20 +343,22 @@ class BERTModel(Block): def __init__(self, encoder, vocab_size=None, token_type_vocab_size=None, units=None, embed_size=None, embed_dropout=0.0, embed_initializer=None, word_embed=None, token_type_embed=None, use_pooler=True, use_decoder=True, - use_classifier=True, prefix=None, params=None): + use_classifier=True, use_token_type_embed=True, prefix=None, params=None): super(BERTModel, self).__init__(prefix=prefix, params=params) self._use_decoder = use_decoder self._use_classifier = use_classifier self._use_pooler = use_pooler + self._use_token_type_embed = use_token_type_embed self._vocab_size = vocab_size self.encoder = encoder # Construct word embedding self.word_embed = self._get_embed(word_embed, vocab_size, embed_size, embed_initializer, embed_dropout, 'word_embed_') # Construct token type embedding - self.token_type_embed = self._get_embed(token_type_embed, token_type_vocab_size, - embed_size, embed_initializer, embed_dropout, - 'token_type_embed_') + if use_token_type_embed: + self.token_type_embed = self._get_embed(token_type_embed, token_type_vocab_size, + embed_size, embed_initializer, embed_dropout, + 'token_type_embed_') if self._use_pooler: # Construct pooler self.pooler = self._get_pooler(units, 'pooler_') @@ -409,7 +416,7 @@ def _get_pooler(self, units, prefix): prefix=prefix) return pooler - def forward(self, inputs, token_types, valid_length=None, masked_positions=None): # pylint: disable=arguments-differ + def forward(self, inputs, token_types=None, valid_length=None, masked_positions=None): # pylint: disable=arguments-differ """Generate the representation given the inputs. This is used in training or fine-tuning a BERT model. @@ -433,9 +440,7 @@ def forward(self, inputs, token_types, valid_length=None, masked_positions=None) if self._use_classifier: next_sentence_classifier_out = self.classifier(pooled_out) outputs.append(next_sentence_classifier_out) - if self._use_decoder: - assert masked_positions is not None, \ - 'masked_positions tensor is required for decoding masked language model' + if self._use_decoder and masked_positions is not None: decoder_out = self._decode(output, masked_positions) outputs.append(decoder_out) return tuple(outputs) if len(outputs) > 1 else outputs[0] @@ -446,9 +451,10 @@ def _encode_sequence(self, inputs, token_types, valid_length=None): This is used for pre-training or fine-tuning a BERT model. """ # embedding - word_embedding = self.word_embed(inputs) - type_embedding = self.token_type_embed(token_types) - embedding = word_embedding + type_embedding + embedding = self.word_embed(inputs) + if self._use_token_type_embed: + type_embedding = self.token_type_embed(token_types) + embedding = embedding + type_embedding # encoding outputs, additional_outputs = self.encoder(embedding, None, valid_length) return outputs, additional_outputs @@ -492,6 +498,76 @@ def _decode(self, sequence, masked_positions): decoded = self.decoder(encoded) return decoded +class RoBERTaModel(BERTModel): + """Generic Model for BERT (Bidirectional Encoder Representations from Transformers). + + Parameters + ---------- + encoder : BERTEncoder + Bidirectional encoder that encodes the input sentence. + vocab_size : int or None, default None + The size of the vocabulary. + units : int or None, default None + Number of units for the final pooler layer. + embed_size : int or None, default None + Size of the embedding vectors. 
It is used to generate the word and token type + embeddings if word_embed and token_type_embed are None. + embed_dropout : float, default 0.0 + Dropout rate of the embedding weights. It is used to generate the source and target + embeddings if word_embed and token_type_embed are None. + embed_initializer : Initializer, default None + Initializer of the embedding weights. It is used to generate the source and target + embeddings if word_embed and token_type_embed are None. + word_embed : Block or None, default None + The word embedding. If set to None, word_embed will be constructed using embed_size and + embed_dropout. + use_decoder : bool, default True + Whether to include the decoder for masked language model prediction. + prefix : str or None + See document of `mx.gluon.Block`. + params : ParameterDict or None + See document of `mx.gluon.Block`. + + Inputs: + - **inputs**: input sequence tensor, shape (batch_size, seq_length) + - **valid_length**: optional tensor of input sequence valid lengths, shape (batch_size,) + - **masked_positions**: optional tensor of position of tokens for masked LM decoding, + shape (batch_size, num_masked_positions). + + Outputs: + - **sequence_outputs**: Encoded sequence, which can be either a tensor of the last + layer of the Encoder, or a list of all sequence encodings of all layers. + In both cases shape of the tensor(s) is/are (batch_size, seq_length, units). + - **attention_outputs**: output list of all intermediate encodings per layer + Returned only if BERTEncoder.output_attention is True. + List of num_layers length of tensors of shape + (num_masks, num_attention_heads, seq_length, seq_length) + - **masked_lm_outputs**: output tensor of sequence decoding for masked language model + prediction. Returned only if use_decoder True. + Shape (batch_size, num_masked_positions, vocab_size) + """ + + def __init__(self, encoder, vocab_size=None, units=None, + embed_size=None, embed_dropout=0.0, embed_initializer=None, + word_embed=None, use_decoder=True, prefix=None, params=None): + super(RoBERTaModel, self).__init__(encoder, vocab_size=vocab_size, + token_type_vocab_size=None, units=units, + embed_size=embed_size, embed_dropout=embed_dropout, + embed_initializer=embed_initializer, + word_embed=word_embed, token_type_embed=None, + use_pooler=False, use_decoder=use_decoder, + use_classifier=False, use_token_type_embed=False, + prefix=prefix, params=params) + + def forward(self, inputs, valid_length=None, masked_positions=None): # pylint: disable=arguments-differ + """Generate the representation given the inputs. + + This is used in training or fine-tuning a BERT model. 
+ """ + return super(RoBERTaModel, self).forward(inputs, token_types=None, + valid_length=valid_length, + masked_positions=masked_positions) + ############################################################################### # GET MODEL # ############################################################################### @@ -503,6 +579,10 @@ def _decode(self, sequence, masked_positions): ('75cc780f085e8007b3bf6769c6348bb1ff9a3074', 'bert_12_768_12_book_corpus_wiki_en_uncased'), ('a56e24015a777329c795eed4ed21c698af03c9ff', 'bert_12_768_12_openwebtext_book_corpus_wiki_en_uncased'), + ('5cf21fcddb5ae1a4c21c61201643460c9d65d3b0', + 'roberta_12_768_12_openwebtext_ccnews_stories_books_cased'), + ('d1b7163e9628e2fd51c9a9f3a0dc519d4fc24add', + 'roberta_24_1024_16_openwebtext_ccnews_stories_books_cased'), ('237f39851b24f0b56d70aa20efd50095e3926e26', 'bert_12_768_12_wiki_multilingual_uncased'), ('b0f57a207f85a7d361bb79de80756a8c9a4276f7', 'bert_12_768_12_wiki_multilingual_cased'), ('885ebb9adc249a170c5576e90e88cfd1bbd98da6', 'bert_12_768_12_wiki_cn_cased'), @@ -521,6 +601,38 @@ def _decode(self, sequence, masked_positions): ('f869f3f89e4237a769f1b7edcbdfe8298b480052', 'ernie_12_768_12_baidu_ernie_uncased'), ]}) +roberta_12_768_12_hparams = { + 'attention_cell': 'multi_head', + 'num_layers': 12, + 'units': 768, + 'hidden_size': 3072, + 'max_length': 512, + 'num_heads': 12, + 'scaled': True, + 'dropout': 0.1, + 'use_residual': True, + 'embed_size': 768, + 'embed_dropout': 0.1, + 'word_embed': None, + 'layer_norm_eps': 1e-5 +} + +roberta_24_1024_16_hparams = { + 'attention_cell': 'multi_head', + 'num_layers': 24, + 'units': 1024, + 'hidden_size': 4096, + 'max_length': 512, + 'num_heads': 16, + 'scaled': True, + 'dropout': 0.1, + 'use_residual': True, + 'embed_size': 1024, + 'embed_dropout': 0.1, + 'word_embed': None, + 'layer_norm_eps': 1e-5 +} + bert_12_768_12_hparams = { 'attention_cell': 'multi_head', 'num_layers': 12, @@ -574,6 +686,8 @@ def _decode(self, sequence, masked_positions): bert_hparams = { 'bert_12_768_12': bert_12_768_12_hparams, 'bert_24_1024_16': bert_24_1024_16_hparams, + 'roberta_12_768_12': roberta_12_768_12_hparams, + 'roberta_24_1024_16': roberta_24_1024_16_hparams, 'ernie_12_768_12': ernie_12_768_12_hparams } @@ -704,6 +818,79 @@ def bert_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu() pretrained_allow_missing=pretrained_allow_missing, **kwargs) +def roberta_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), + use_decoder=True, + root=os.path.join(get_home_dir(), 'models'), **kwargs): + """Generic RoBERTa BASE model. + + The number of layers (L) is 12, number of units (H) is 768, and the + number of self-attention heads (A) is 12. + + Parameters + ---------- + dataset_name : str or None, default None + If not None, the dataset name is used to load a vocabulary for the + dataset. If the `pretrained` argument is set to True, the dataset name + is further used to select the pretrained parameters to load. + Options include 'book_corpus_wiki_en_uncased' and 'book_corpus_wiki_en_cased'. + vocab : gluonnlp.vocab.Vocab or None, default None + Vocabulary for the dataset. Must be provided if dataset_name is not + specified. Ignored if dataset_name is specified. + pretrained : bool, default True + Whether to load the pretrained weights for model. + ctx : Context, default CPU + The context in which to load the pretrained weights. + root : str, default '$MXNET_HOME/models' + Location for keeping the model parameters. + MXNET_HOME defaults to '~/.mxnet'. 
+ use_decoder : bool, default True + Whether to include the decoder for masked language model prediction. + + Returns + ------- + RoBERTaModel, gluonnlp.vocab.Vocab + """ + return get_roberta_model(model_name='roberta_12_768_12', vocab=vocab, dataset_name=dataset_name, + pretrained=pretrained, ctx=ctx, + use_decoder=use_decoder, root=root, **kwargs) + + +def roberta_24_1024_16(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), + use_decoder=True, + root=os.path.join(get_home_dir(), 'models'), **kwargs): + """Generic RoBERTa LARGE model. + + The number of layers (L) is 24, number of units (H) is 1024, and the + number of self-attention heads (A) is 16. + + Parameters + ---------- + dataset_name : str or None, default None + If not None, the dataset name is used to load a vocabulary for the + dataset. If the `pretrained` argument is set to True, the dataset name + is further used to select the pretrained parameters to load. + Options include 'book_corpus_wiki_en_uncased' and 'book_corpus_wiki_en_cased'. + vocab : gluonnlp.vocab.Vocab or None, default None + Vocabulary for the dataset. Must be provided if dataset_name is not + specified. Ignored if dataset_name is specified. + pretrained : bool, default True + Whether to load the pretrained weights for model. + ctx : Context, default CPU + The context in which to load the pretrained weights. + root : str, default '$MXNET_HOME/models' + Location for keeping the model parameters. + MXNET_HOME defaults to '~/.mxnet'. + use_decoder : bool, default True + Whether to include the decoder for masked language model prediction. + + Returns + ------- + RoBERTaModel, gluonnlp.vocab.Vocab + """ + return get_roberta_model(model_name='roberta_24_1024_16', vocab=vocab, + dataset_name=dataset_name, pretrained=pretrained, ctx=ctx, + use_decoder=use_decoder, root=root, **kwargs) + def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), root=os.path.join(get_home_dir(), 'models'), use_pooler=True, use_decoder=True, use_classifier=True, **kwargs): @@ -751,9 +938,87 @@ def ernie_12_768_12(dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu() pretrained_allow_missing=False, **kwargs) +def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), + use_decoder=True, output_attention=False, output_all_encodings=False, + root=os.path.join(get_home_dir(), 'models'), **kwargs): + """Any RoBERTa pretrained model. + + Parameters + ---------- + model_name : str or None, default None + Options include 'bert_24_1024_16' and 'bert_12_768_12'. + dataset_name : str or None, default None + If not None, the dataset name is used to load a vocabulary for the + dataset. If the `pretrained` argument is set to True, the dataset name + is further used to select the pretrained parameters to load. + The supported datasets for model_name of either roberta_24_1024_16 and + roberta_12_768_12 include 'openwebtext_ccnews_stories_books'. + vocab : gluonnlp.vocab.Vocab or None, default None + Vocabulary for the dataset. Must be provided if dataset_name is not + specified. Ignored if dataset_name is specified. + pretrained : bool, default True + Whether to load the pretrained weights for model. + ctx : Context, default CPU + The context in which to load the pretrained weights. + root : str, default '$MXNET_HOME/models' + Location for keeping the model parameters. + MXNET_HOME defaults to '~/.mxnet'. + use_decoder : bool, default True + Whether to include the decoder for masked language model prediction. 
+ Note that + 'biobert_v1.0_pmc', 'biobert_v1.0_pubmed', 'biobert_v1.0_pubmed_pmc', + 'biobert_v1.1_pubmed', + 'clinicalbert' + do not include these parameters. + output_attention : bool, default False + Whether to include attention weights of each encoding cell to the output. + output_all_encodings : bool, default False + Whether to output encodings of all encoder cells. + + Returns + ------- + RoBERTaModel, gluonnlp.vocab.Vocab + """ + predefined_args = bert_hparams[model_name] + mutable_args = ['use_residual', 'dropout', 'embed_dropout', 'word_embed'] + mutable_args = frozenset(mutable_args) + assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \ + 'Cannot override predefined model settings.' + predefined_args.update(kwargs) + # encoder + encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'], + num_layers=predefined_args['num_layers'], + units=predefined_args['units'], + hidden_size=predefined_args['hidden_size'], + max_length=predefined_args['max_length'], + num_heads=predefined_args['num_heads'], + scaled=predefined_args['scaled'], + dropout=predefined_args['dropout'], + output_attention=output_attention, + output_all_encodings=output_all_encodings, + use_residual=predefined_args['use_residual'], + activation=predefined_args.get('activation', 'gelu'), + layer_norm_eps=predefined_args.get('layer_norm_eps', None)) + + from ..vocab import Vocab + bert_vocab = _load_vocab(dataset_name, vocab, root, cls=Vocab) + # BERT + net = RoBERTaModel(encoder, len(bert_vocab), + units=predefined_args['units'], + embed_size=predefined_args['embed_size'], + embed_dropout=predefined_args['embed_dropout'], + word_embed=predefined_args['word_embed'], + use_decoder=use_decoder) + if pretrained: + ignore_extra = not use_decoder + _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra, + allow_missing=False) + return net, bert_vocab + def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=True, ctx=mx.cpu(), use_pooler=True, use_decoder=True, use_classifier=True, output_attention=False, - output_all_encodings=False, root=os.path.join(get_home_dir(), 'models'), + output_all_encodings=False, use_token_type_embed=True, + root=os.path.join(get_home_dir(), 'models'), pretrained_allow_missing=False, **kwargs): """Any BERT pretrained model. 
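For context on the API split introduced in this file, ``RoBERTaModel`` wraps ``BERTModel`` with the pooler, classifier and token type (segment) embedding disabled, so its ``forward`` takes no ``token_types`` argument. A minimal sketch of the two calling conventions (dataset names and tensor shapes are illustrative, and assume the pretrained weights are available):

.. code-block:: python

    import mxnet as mx
    import gluonnlp as nlp

    inputs = mx.nd.ones((2, 8))        # (batch_size, seq_length) token ids
    token_types = mx.nd.zeros((2, 8))  # segment ids, consumed by BERT only
    valid_length = mx.nd.array([8, 6])

    # BERTModel: token types are expected when use_token_type_embed=True (the default)
    bert, _ = nlp.model.get_model('bert_12_768_12',
                                  dataset_name='book_corpus_wiki_en_uncased',
                                  use_classifier=False, use_decoder=False)
    seq_encoding, cls_encoding = bert(inputs, token_types, valid_length)

    # RoBERTaModel: no token types; only the sequence encoding is returned
    roberta, _ = nlp.model.get_model('roberta_12_768_12',
                                     dataset_name='openwebtext_ccnews_stories_books_cased',
                                     use_decoder=False)
    seq_encoding = roberta(inputs, valid_length)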
@@ -842,8 +1107,9 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
                           use_residual=predefined_args['use_residual'],
                           activation=predefined_args.get('activation', 'gelu'),
                           layer_norm_eps=predefined_args.get('layer_norm_eps', None))
-    # bert_vocab
+
     from ..vocab import BERTVocab
+    # bert_vocab
     bert_vocab = _load_vocab(dataset_name, vocab, root, cls=BERTVocab)
     # BERT
     net = BERTModel(encoder, len(bert_vocab),
@@ -853,7 +1119,8 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr
                     embed_dropout=predefined_args['embed_dropout'],
                     word_embed=predefined_args['word_embed'], use_pooler=use_pooler,
                     use_decoder=use_decoder,
-                    use_classifier=use_classifier)
+                    use_classifier=use_classifier,
+                    use_token_type_embed=use_token_type_embed)
     if pretrained:
         ignore_extra = not (use_pooler and use_decoder and use_classifier)
         _load_pretrained_params(net, model_name, dataset_name, root, ctx, ignore_extra=ignore_extra,
diff --git a/tests/unittest/test_models.py b/tests/unittest/test_models.py
index 0f65befe22..1552359ccd 100644
--- a/tests/unittest/test_models.py
+++ b/tests/unittest/test_models.py
@@ -96,6 +96,39 @@ def test_transformer_models():
     mx.nd.waitall()
 
+
+@pytest.mark.serial
+@pytest.mark.remote_required
+def test_pretrained_roberta_models():
+    models = ['roberta_12_768_12', 'roberta_24_1024_16']
+    pretrained_datasets = ['openwebtext_ccnews_stories_books_cased']
+
+    vocab_size = {'openwebtext_ccnews_stories_books_cased': 50265}
+    special_tokens = ['<unk>', '<pad>', '<s>', '</s>', '<mask>']
+    ones = mx.nd.ones((2, 10))
+    valid_length = mx.nd.ones((2,))
+    positions = mx.nd.zeros((2, 3))
+    for model_name in models:
+        for dataset in pretrained_datasets:
+            eprint('testing forward for %s on %s' % (model_name, dataset))
+
+            model, vocab = nlp.model.get_model(model_name, dataset_name=dataset,
+                                               pretrained=True,
+                                               root='tests/data/model/')
+
+            assert len(vocab) == vocab_size[dataset]
+            for token in special_tokens:
+                assert token in vocab, "Token %s not found in the vocab" % token
+            assert vocab['RandomWordByHaibin'] == vocab[vocab.unknown_token]
+            assert vocab.padding_token == '<pad>'
+            assert vocab.unknown_token == '<unk>'
+            assert vocab.bos_token == '<s>'
+            assert vocab.eos_token == '</s>'
+
+            output = model(ones, valid_length, positions)
+            output[0].wait_to_read()
+            del model
+            mx.nd.waitall()
+
 @pytest.mark.serial
 @pytest.mark.remote_required
 @pytest.mark.parametrize('disable_missing_parameters', [False, True])
@@ -565,4 +598,3 @@ def test_transformer_encoder():
     outputs.wait_to_read()
     mx.nd.waitall()
     assert outputs.shape == (batch_size, seq_length, units)
-
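As a follow-up to the conversion script above, a minimal sketch of how the artifacts it saves (``roberta_12_768_12.params`` and ``roberta_12_768_12.vocab``) could be loaded back directly, without going through the model store; the checkpoint directory matches the script's default and is purely illustrative:

.. code-block:: python

    import io
    import os
    import mxnet as mx
    import gluonnlp as nlp

    ckpt_dir = '/home/ubuntu/roberta/roberta.base'   # assumed local path
    model_name = 'roberta_12_768_12'

    # the vocabulary was serialized with vocab.to_json()
    with io.open(os.path.join(ckpt_dir, model_name + '.vocab'), encoding='utf-8') as f:
        vocab = nlp.Vocab.from_json(f.read())

    # build the RoBERTa graph without downloading weights, then load the converted ones
    model, _ = nlp.model.get_model(model_name, vocab=vocab, pretrained=False)
    model.load_parameters(os.path.join(ckpt_dir, model_name + '.params'), ctx=mx.cpu())

    tokenizer = nlp.data.GPT2BPETokenizer()
    tokens = [vocab.bos_token] + tokenizer('Hello world. abc, def and δΈ­ζ–‡!') + [vocab.eos_token]
    seq_encoding = model(mx.nd.array([vocab[tokens]]))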