diff --git a/scripts/bert/compare_tf_gluon_model.py b/scripts/bert/compare_tf_gluon_model.py
index e507f0e773..e35c216d52 100644
--- a/scripts/bert/compare_tf_gluon_model.py
+++ b/scripts/bert/compare_tf_gluon_model.py
@@ -26,17 +26,16 @@
 import numpy as np
 import mxnet as mx
 import gluonnlp as nlp
-from gluonnlp.data import TSVDataset
-from gluonnlp.data import BERTTokenizer
-from gluon.data import BERTSentenceTransform


 parser = argparse.ArgumentParser(description='Comparison script for BERT model in Tensorflow'
-                                             'and that in Gluon')
+                                             ' and that in Gluon. This script works with '
+                                             'google/bert@f39e881b')
 parser.add_argument('--input_file', type=str, default='input.txt',
                     help='sample input file for testing. Default is input.txt')
 parser.add_argument('--tf_bert_repo_dir', type=str, default='~/bert/',
                     help='path to the original Tensorflow bert repository. '
+                         'The repo should be at f39e881b. '
                          'Default is ~/bert/')
 parser.add_argument('--tf_model_dir', type=str,
                     default='~/uncased_L-12_H-768_A-12/',
@@ -48,15 +47,17 @@
                     help='gluon dataset name. Default is book_corpus_wiki_en_uncased')
 parser.add_argument('--gluon_model', type=str, default='bert_12_768_12',
                     help='gluon model name. Default is bert_12_768_12')
+parser.add_argument('--gluon_parameter_file', type=str, default=None,
+                    help='gluon parameter file name.')

 args = parser.parse_args()

 input_file = os.path.expanduser(args.input_file)
 tf_bert_repo_dir = os.path.expanduser(args.tf_bert_repo_dir)
 tf_model_dir = os.path.expanduser(args.tf_model_dir)
-vocab_file = tf_model_dir + 'vocab.txt'
-bert_config_file = tf_model_dir + 'bert_config.json'
-init_checkpoint = tf_model_dir + 'bert_model.ckpt'
+vocab_file = os.path.join(tf_model_dir, 'vocab.txt')
+bert_config_file = os.path.join(tf_model_dir, 'bert_config.json')
+init_checkpoint = os.path.join(tf_model_dir, 'bert_model.ckpt')
 do_lower_case = not args.cased
 max_length = 128

@@ -130,13 +131,24 @@
 bert, vocabulary = nlp.model.get_model(args.gluon_model,
                                        dataset_name=args.gluon_dataset,
-                                       pretrained=True, use_pooler=False,
-                                       use_decoder=False, use_classifier=False)
+                                       pretrained=not args.gluon_parameter_file,
+                                       use_pooler=False,
+                                       use_decoder=False,
+                                       use_classifier=False)
+if args.gluon_parameter_file:
+    try:
+        bert.cast('float16')
+        bert.load_parameters(args.gluon_parameter_file, ignore_extra=True)
+        bert.cast('float32')
+    except AssertionError:
+        bert.cast('float32')
+        bert.load_parameters(args.gluon_parameter_file, ignore_extra=True)
+
 print(bert)

-tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)
-dataset = TSVDataset(input_file, field_separator=nlp.data.Splitter(' ||| '))
+tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=do_lower_case)
+dataset = nlp.data.TSVDataset(input_file, field_separator=nlp.data.Splitter(' ||| '))

-trans = BERTSentenceTransform(tokenizer, max_length)
+trans = nlp.data.BERTSentenceTransform(tokenizer, max_length)
 dataset = dataset.transform(trans)

 bert_dataloader = mx.gluon.data.DataLoader(dataset, batch_size=1,
@@ -152,7 +164,5 @@
         b = out[0][:length].asnumpy()

         print('stdev = %s' % (np.std(a - b)))
-        mx.test_utils.assert_almost_equal(a, b, atol=1e-4, rtol=1e-4)
-        mx.test_utils.assert_almost_equal(a, b, atol=1e-5, rtol=1e-5)
         mx.test_utils.assert_almost_equal(a, b, atol=5e-6, rtol=5e-6)
         break
diff --git a/scripts/bert/convert_tf_model.py b/scripts/bert/convert_tf_model.py
index 1bd9f03a23..f2619822bb 100644
--- a/scripts/bert/convert_tf_model.py
+++ b/scripts/bert/convert_tf_model.py
@@ -32,11 +32,11 @@
                     help='BERT model name. options are bert_12_768_12 and bert_24_1024_16.'
                          'Default is bert_12_768_12')
 parser.add_argument('--tf_checkpoint_dir', type=str,
-                    default='/home/ubuntu/cased_L-12_H-768_A-12/',
+                    default=os.path.join('~', 'cased_L-12_H-768_A-12/'),
                     help='Path to Tensorflow checkpoint folder. '
-                         'Default is /home/ubuntu/cased_L-12_H-768_A-12/')
+                         'Default is ~/cased_L-12_H-768_A-12/')
 parser.add_argument('--out_dir', type=str,
-                    default='/home/ubuntu/output/',
+                    default=os.path.join('~', 'output'),
                     help='Path to output folder. The folder must exist. '
-                         'Default is /home/ubuntu/output/')
+                         'Default is ~/output')
 parser.add_argument('--debug', action='store_true', help='debugging mode')
@@ -49,17 +49,18 @@
 vocab, reserved_token_idx_map = convert_vocab(vocab_path)

 # vocab serialization
-tmp_file_path = os.path.join(args.out_dir, 'tmp')
+tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
 with open(tmp_file_path, 'w') as f:
     f.write(vocab.to_json())
 hash_full, hash_short = get_hash(tmp_file_path)
-gluon_vocab_path = os.path.join(args.out_dir, hash_short + '.vocab')
+gluon_vocab_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.vocab'))
 with open(gluon_vocab_path, 'w') as f:
     f.write(vocab.to_json())
 logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path, hash_full)

 # load tf model
-tf_checkpoint_file = os.path.join(args.tf_checkpoint_dir, 'bert_model.ckpt')
+tf_checkpoint_file = os.path.expanduser(
+    os.path.join(args.tf_checkpoint_dir, 'bert_model.ckpt'))
 logging.info('loading Tensorflow checkpoint %s ...', tf_checkpoint_file)
 tf_tensors = read_tf_checkpoint(tf_checkpoint_file)
 tf_names = sorted(tf_tensors.keys())
@@ -177,7 +178,7 @@
 # param serialization
 bert.save_parameters(tmp_file_path)
 hash_full, hash_short = get_hash(tmp_file_path)
-gluon_param_path = os.path.join(args.out_dir, hash_short + '.params')
+gluon_param_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.params'))
 logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full)
 bert.save_parameters(gluon_param_path)
 mx.nd.waitall()
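
For readers converting their own checkpoints, the parameter-loading fallback added to compare_tf_gluon_model.py can be exercised on its own. Below is a minimal sketch under stated assumptions: the .params path is a hypothetical placeholder for a file emitted by convert_tf_model.py, and the model and dataset names mirror the script defaults.

# Standalone sketch of the float16-first loading fallback from the patch above.
import os

import gluonnlp as nlp

# pretrained=False builds only the architecture and vocabulary; the weights
# come from the converted checkpoint rather than the model zoo.
bert, vocabulary = nlp.model.get_model('bert_12_768_12',
                                       dataset_name='book_corpus_wiki_en_uncased',
                                       pretrained=False,
                                       use_pooler=False,
                                       use_decoder=False,
                                       use_classifier=False)
param_file = os.path.expanduser('~/output/converted.params')  # hypothetical path

try:
    # A checkpoint serialized in half precision only loads after the
    # network itself has been cast to float16.
    bert.cast('float16')
    bert.load_parameters(param_file, ignore_extra=True)
    bert.cast('float32')
except AssertionError:
    # The load fails for a float32 checkpoint; reset the dtype and reload.
    bert.cast('float32')
    bert.load_parameters(param_file, ignore_extra=True)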
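Two remarks on the design. The try/except ordering gives a single code path for checkpoints converted in either precision: loading float32 weights into a float16-cast network trips Gluon's parameter consistency assertion, at which point the script falls back to a plain float32 load. And because convert_tf_model.py names its outputs by content hash (hash_short + '.vocab' and hash_short + '.params'), the emitted .params path must be read off the conversion log and passed to --gluon_parameter_file explicitly when chaining the two scripts. The looser 1e-4/1e-5 assertions were dropped as redundant, since the surviving 5e-6 check subsumes them.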