This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[model] Roberta converted weights (#870)
* +roberta

* fix vocab

* remove self attention

* add model store

* add test

* add doc

* fix doc

* fix test

* fix lint

* separate class for roberta

* fix lint

* fix doc
eric-haibin-lin authored and szha committed Aug 8, 2019
1 parent da936e0 commit 3922d06
Showing 10 changed files with 659 additions and 26 deletions.
216 changes: 216 additions & 0 deletions scripts/bert/conversion_tools/convert_fairseq_model.py
@@ -0,0 +1,216 @@
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# 'License'); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name,logging-format-interpolation
""" Script for converting Fairseq Roberta Model to Gluon. """
import argparse
import logging
import os
import sys
import io
import numpy as np

import torch
from fairseq.models.roberta import RobertaModel

import mxnet as mx
import gluonnlp as nlp
from gluonnlp.model import BERTEncoder, BERTModel
from gluonnlp.model.bert import bert_hparams
from gluonnlp.data.utils import _load_pretrained_vocab

from utils import get_hash, load_text_vocab, tf_vocab_to_gluon_vocab

parser = argparse.ArgumentParser(description='Conversion script for Fairseq RoBERTa model',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--ckpt_dir', type=str, help='Full path to the roberta folder',
                    default='/home/ubuntu/roberta/roberta.base')
parser.add_argument('--model', type=str, help='Model type. ',
                    choices=['roberta_12_768_12', 'roberta_24_1024_16'],
                    default='roberta_12_768_12')
parser.add_argument('--verbose', action='store_true', help='Verbose logging')

args = parser.parse_args()

ckpt_dir = os.path.expanduser(args.ckpt_dir)

ckpt = torch.load(os.path.join(ckpt_dir, 'model.pt'))
pytorch_params = ckpt['model']

if args.verbose:
    print(ckpt['args'])
    for k, v in pytorch_params.items():
        print(k, v.shape)

# Load the model in fairseq
roberta = RobertaModel.from_pretrained(ckpt_dir)
roberta.eval()

def fairseq_vocab_to_gluon_vocab(torch_vocab):
    index_to_words = [None] * len(torch_vocab)

    bos_idx = torch_vocab.bos()
    pad_idx = torch_vocab.pad()
    eos_idx = torch_vocab.eos()
    unk_idx = torch_vocab.unk()

    index_to_words[bos_idx] = torch_vocab.symbols[bos_idx]
    index_to_words[pad_idx] = torch_vocab.symbols[pad_idx]
    index_to_words[eos_idx] = torch_vocab.symbols[eos_idx]
    index_to_words[unk_idx] = torch_vocab.symbols[unk_idx]

    specials = [bos_idx, pad_idx, eos_idx, unk_idx]

    openai_to_roberta = {}
    openai_vocab = _load_pretrained_vocab('openai_webtext', '.')

    with io.open(os.path.join(ckpt_dir, 'dict.txt'), encoding='utf-8') as f:
        for i, line in enumerate(f):
            token, count = line.split(' ')
            try:
                fake_token = int(token)
                openai_to_roberta[token] = i + len(specials)
            except ValueError:
                index_to_words[i + len(specials)] = token

    for idx, token in enumerate(openai_vocab.idx_to_token):
        if str(idx) in openai_to_roberta:
            index_to_words[openai_to_roberta[str(idx)]] = token
        else:
            assert token == u'<mask>', token

    mask_idx = torch_vocab.index(u'<mask>')
    index_to_words[mask_idx] = torch_vocab.string([mask_idx])
    assert None not in index_to_words
    word2idx = {}
    for idx, token in enumerate(index_to_words):
        word2idx[token] = idx

    vocab = nlp.vocab.Vocab(word2idx, token_to_idx=word2idx,
                            unknown_token=index_to_words[unk_idx],
                            padding_token=index_to_words[pad_idx],
                            bos_token=index_to_words[bos_idx],
                            eos_token=index_to_words[eos_idx],
                            mask_token=u'<mask>')
    return vocab

vocab = fairseq_vocab_to_gluon_vocab(roberta.task.dictionary)

predefined_args = bert_hparams[args.model]

# BERT encoder
encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                      num_layers=predefined_args['num_layers'], units=predefined_args['units'],
                      hidden_size=predefined_args['hidden_size'],
                      max_length=predefined_args['max_length'],
                      num_heads=predefined_args['num_heads'], scaled=predefined_args['scaled'],
                      dropout=predefined_args['dropout'],
                      use_residual=predefined_args['use_residual'],
                      layer_norm_eps=predefined_args['layer_norm_eps'])

# BERT model
bert = BERTModel(encoder, len(vocab),
                 units=predefined_args['units'], embed_size=predefined_args['embed_size'],
                 embed_dropout=predefined_args['embed_dropout'],
                 word_embed=predefined_args['word_embed'], use_pooler=False,
                 use_token_type_embed=False, use_classifier=False)

bert.initialize(init=mx.init.Normal(0.02))

ones = mx.nd.ones((2, 8))
out = bert(ones, None, mx.nd.array([5, 6]), mx.nd.array([[1], [2]]))
params = bert._collect_params_with_prefix()
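# Note: the dummy forward pass above triggers Gluon's deferred shape inference, so every
# parameter has a concrete shape before the fairseq/PyTorch weights are copied in below.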



mapping = {
'decoder.2' : 'decoder.lm_head.layer_norm',
'decoder.0' : 'decoder.lm_head.dense',
'decoder.3' : 'decoder.lm_head',
'encoder.layer_norm' : 'decoder.sentence_encoder.emb_layer_norm',
'encoder.position_weight' : 'decoder.sentence_encoder.embed_positions.weight',
'encoder.transformer_cells': 'decoder.sentence_encoder.layers',
'attention_cell.proj_key.' : 'self_attn.in_proj_',
'attention_cell.proj_value.' : 'self_attn.in_proj_',
'attention_cell.proj_query.' : 'self_attn.in_proj_',
'ffn.ffn_1' : 'fc1',
'ffn.ffn_2' : 'fc2',
'layer_norm.gamma' : 'layer_norm.weight',
'layer_norm.beta' : 'layer_norm.bias',
'ffn.layer_norm' : 'final_layer_norm',
'word_embed.0.weight' : 'decoder.sentence_encoder.embed_tokens.weight',
}

for i in range(24):
    mapping['{}.layer_norm'.format(i)] = '{}.self_attn_layer_norm'.format(i)
    mapping['{}.proj'.format(i)] = '{}.self_attn.out_proj'.format(i)
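# For example, the Gluon parameter name 'encoder.transformer_cells.0.attention_cell.proj_query.weight'
# is rewritten by the substitutions above into the fairseq name
# 'decoder.sentence_encoder.layers.0.self_attn.in_proj_weight'.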

# set parameter data
loaded_params = {}
visited_pytorch_params = {}
for name in params:
    pytorch_name = name
    for source, dest in mapping.items():
        pytorch_name = pytorch_name.replace(source, dest)

    assert pytorch_name in pytorch_params.keys(), 'Key ' + pytorch_name + ' for ' + name + ' not found.'
    torch_arr = pytorch_params[pytorch_name].cpu()
    # fairseq positional embedding starts with index 2
    if pytorch_name == 'decoder.sentence_encoder.embed_positions.weight':
        torch_arr = torch_arr[2:]

    arr = mx.nd.array(torch_arr)
    if 'attention_cell.proj' in name:
        unfused = ['query', 'key', 'value']
        arrs = arr.split(num_outputs=3, axis=0)
        for i, p in enumerate(unfused):
            if p in name:
                arr = arrs[i]
    else:
        assert arr.shape == params[name].shape, (arr.shape, params[name].shape, name, pytorch_name)
    params[name].set_data(arr)
    loaded_params[name] = True
    visited_pytorch_params[pytorch_name] = True

assert len(params) == len(loaded_params)
assert len(visited_pytorch_params) == len(pytorch_params), "Gluon model does not match PyTorch model. " \
    "Please fix the BERTModel hyperparameters\n" + str(len(visited_pytorch_params)) + ' v.s. ' + str(len(pytorch_params))


texts = 'Hello world. abc, def and 中文!'
torch_tokens = roberta.encode(texts)

torch_features = roberta.extract_features(torch_tokens)
pytorch_out = torch_features.detach().numpy()

mx_tokenizer = nlp.data.GPT2BPETokenizer()
mx_tokens = [vocab.bos_token] + mx_tokenizer(texts) + [vocab.eos_token]
mx_data = vocab[mx_tokens]
print(mx_tokens)
print(vocab[mx_tokens])
print(torch_tokens)
assert mx_data == torch_tokens.tolist()

mx_out = bert(mx.nd.array([mx_data]))
print('stdev = ', np.std(mx_out.asnumpy() - pytorch_out))
mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=1e-3, rtol=1e-3)
mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=5e-6, rtol=5e-6)

bert.save_parameters(os.path.join(ckpt_dir, args.model + '.params'))
with io.open(os.path.join(ckpt_dir, args.model + '.vocab'), 'w', encoding='utf-8') as f:
    f.write(vocab.to_json())
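Once the converted parameters and vocabulary are published to the GluonNLP model zoo, the model can be loaded by name. A minimal sketch, mirroring the documentation example added below ('Hello world!' is just a stand-in input):

    import mxnet as mx
    import gluonnlp as nlp

    # Fetch the converted RoBERTa base model and its vocabulary from the model zoo
    model, vocab = nlp.model.get_model('roberta_12_768_12',
                                       dataset_name='openwebtext_ccnews_stories_books_cased')
    tokenizer = nlp.data.GPT2BPETokenizer()
    tokens = [vocab.bos_token] + tokenizer('Hello world!') + [vocab.eos_token]
    seq_encoding = model(mx.nd.array([vocab[tokens]]))  # (batch_size, seq_length, units)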
74 changes: 74 additions & 0 deletions scripts/bert/conversion_tools/utils.py
@@ -0,0 +1,74 @@
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility functions for BERT."""

import logging
import collections
import hashlib
import io

import mxnet as mx
import gluonnlp as nlp

__all__ = ['tf_vocab_to_gluon_vocab', 'load_text_vocab']


def tf_vocab_to_gluon_vocab(tf_vocab):
    special_tokens = ['[UNK]', '[PAD]', '[SEP]', '[MASK]', '[CLS]']
    assert all(t in tf_vocab for t in special_tokens)
    counter = nlp.data.count_tokens(tf_vocab.keys())
    vocab = nlp.vocab.BERTVocab(counter, token_to_idx=tf_vocab)
    return vocab


def get_hash(filename):
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)
    return sha1.hexdigest(), str(sha1.hexdigest())[:8]
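# Example: full_hash, short_hash = get_hash('roberta_12_768_12.params') returns the file's
# full SHA-1 hex digest together with its first 8 characters as a short identifier.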


def read_tf_checkpoint(path):
    """read tensorflow checkpoint"""
    from tensorflow.python import pywrap_tensorflow
    tensors = {}
    reader = pywrap_tensorflow.NewCheckpointReader(path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        tensor = reader.get_tensor(key)
        tensors[key] = tensor
    return tensors

def load_text_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with io.open(vocab_file, 'r') as reader:
        while True:
            token = reader.readline()
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab
25 changes: 25 additions & 0 deletions scripts/bert/index.rst
@@ -47,6 +47,31 @@ The following pre-trained BERT models are available from the **gluonnlp.model.ge

where **bert_12_768_12** refers to the BERT BASE model, and **bert_24_1024_16** refers to the BERT LARGE model.

.. code-block:: python

    import gluonnlp as nlp; import mxnet as mx;
    model, vocab = nlp.model.get_model('bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased', use_classifier=False);
    tokenizer = nlp.data.BERTTokenizer(vocab, lower=True);
    transform = nlp.data.BERTSentenceTransform(tokenizer, max_seq_length=512, pair=False, pad=False);
    sample = transform(['Hello world!']);
    words, valid_len, segments = mx.nd.array([sample[0]]), mx.nd.array([sample[1]]), mx.nd.array([sample[2]]);
    seq_encoding, cls_encoding = model(words, segments, valid_len);

Additionally, GluonNLP supports the "`RoBERTa <https://arxiv.org/abs/1907.11692>`_" model:

+-----------------------------------------+-------------------+--------------------+
|                                         | roberta_12_768_12 | roberta_24_1024_16 |
+=========================================+===================+====================+
| openwebtext_ccnews_stories_books_cased  | ✓                 | ✓                  |
+-----------------------------------------+-------------------+--------------------+

.. code-block:: python

    import gluonnlp as nlp; import mxnet as mx;
    model, vocab = nlp.model.get_model('roberta_12_768_12', dataset_name='openwebtext_ccnews_stories_books_cased');
    tokenizer = nlp.data.GPT2BPETokenizer();
    text = [vocab.bos_token] + tokenizer('Hello world!') + [vocab.eos_token];
    seq_encoding = model(mx.nd.array([vocab[text]]))

.. hint::

19 changes: 18 additions & 1 deletion scripts/bert/pretraining_utils.py
@@ -39,7 +39,7 @@

__all__ = ['get_model_loss', 'get_pretrain_data_npz', 'get_dummy_dataloader',
           'save_parameters', 'save_states', 'evaluate', 'forward', 'split_and_load',
           'get_argparser', 'get_pretrain_data_text', 'generate_dev_set']
           'get_argparser', 'get_pretrain_data_text', 'generate_dev_set', 'profile']

def get_model_loss(ctx, model, pretrained, dataset_name, vocab, dtype,
                   ckpt_dir=None, start_step=None):
@@ -505,3 +505,20 @@ def generate_dev_set(tokenizer, vocab, cache_file, args):
                           1, args.num_data_workers,
                           worker_pool, cache_file))
    logging.info('Done generating validation set on rank 0.')

def profile(curr_step, start_step, end_step, profile_name='profile.json',
            early_exit=True):
    """profile the program between [start_step, end_step)."""
    if curr_step == start_step:
        mx.nd.waitall()
        mx.profiler.set_config(profile_memory=False, profile_symbolic=True,
                               profile_imperative=True, filename=profile_name,
                               aggregate_stats=True)
        mx.profiler.set_state('run')
    elif curr_step == end_step:
        mx.nd.waitall()
        mx.profiler.set_state('stop')
        logging.info(mx.profiler.dumps())
        mx.profiler.dump()
        if early_exit:
            exit()
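A minimal usage sketch for profile; the surrounding flag, loop variables, and step values are illustrative and not part of this diff:

    # inside the pretraining loop, called once per step with the current global step
    for step_num, data_batch in enumerate(dataloader):
        if args.profile:  # hypothetical command-line switch
            profile(step_num, start_step=3, end_step=6)
        # ... forward / backward / parameter update ...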
3 changes: 1 addition & 2 deletions scripts/bert/run_pretraining.py
@@ -37,11 +37,10 @@
import mxnet as mx
import gluonnlp as nlp

from utils import profile
from fp16_utils import FP16Trainer
from pretraining_utils import get_model_loss, get_pretrain_data_npz, get_dummy_dataloader
from pretraining_utils import log, evaluate, forward, split_and_load, get_argparser
from pretraining_utils import save_parameters, save_states
from pretraining_utils import save_parameters, save_states, profile

# arg parser
parser = get_argparser()
3 changes: 1 addition & 2 deletions scripts/bert/run_pretraining_hvd.py
@@ -40,11 +40,10 @@
import mxnet as mx
import gluonnlp as nlp

from utils import profile
from fp16_utils import FP16Trainer
from pretraining_utils import get_model_loss, get_pretrain_data_npz, get_dummy_dataloader
from pretraining_utils import split_and_load, log, evaluate, forward, get_argparser
from pretraining_utils import save_parameters, save_states
from pretraining_utils import save_parameters, save_states, profile
from pretraining_utils import get_pretrain_data_text, generate_dev_set

# parser
2 changes: 2 additions & 0 deletions src/gluonnlp/data/utils.py
@@ -225,6 +225,8 @@ def _slice_pad_length(num_items, length, overlap=0):
               'book_corpus_wiki_en_uncased': 'a66073971aa0b1a262453fe51342e57166a8abcf',
               'openwebtext_book_corpus_wiki_en_uncased':
                   'a66073971aa0b1a262453fe51342e57166a8abcf',
               'openwebtext_ccnews_stories_books_cased':
                   '2b804f8f90f9f93c07994b703ce508725061cf43',
               'wiki_multilingual_cased': '0247cb442074237c38c62021f36b7a4dbd2e55f7',
               'wiki_cn_cased': 'ddebd8f3867bca5a61023f73326fb125cf12b4f5',
               'wiki_multilingual_uncased': '2b2514cc539047b9179e9d98a4e68c36db05c97a',
2 changes: 2 additions & 0 deletions src/gluonnlp/model/__init__.py
@@ -142,6 +142,8 @@ def get_model(name, **kwargs):
              'transformer_en_de_512': transformer_en_de_512,
              'bert_12_768_12' : bert_12_768_12,
              'bert_24_1024_16' : bert_24_1024_16,
              'roberta_12_768_12' : roberta_12_768_12,
              'roberta_24_1024_16' : roberta_24_1024_16,
              'ernie_12_768_12' : ernie_12_768_12}
    name = name.lower()
    if name not in models: