Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
[API] Make BERT a hybrid block (#877)
Browse files Browse the repository at this point in the history
* change BaseTransformerEncoder forward

* fix

* merge _forward to hybrid_forward

* if arange_like is not available

* fix

* make transformer encoder a true hybridblock

* fix lint

* fix bug

* use zeroslike and infer_range to avoid backprop problem

* update test

* Revert "update test"

This reverts commit d024a0c.

* more printing

* fix test case

* make bert hybrid

* add legacy model

* fix lint

* fix lint

* revert change for default parameter

* fix arange dtype

* fix lint

* revert monkey patch in NMT interface

* fix dtype

* fix bug in arange

* remove legacy model

* also update bert scripts

* fix lint

* fix typo

* remove hybridbert test
  • Loading branch information
eric-haibin-lin committed Aug 20, 2019
1 parent e908dd8 commit 7847de2
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 1,192 deletions.
19 changes: 10 additions & 9 deletions scripts/bert/export/export.py → scripts/bert/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@

import mxnet as mx
import gluonnlp as nlp
from hybrid_bert import get_hybrid_model
from hybrid_bert import HybridBERTClassifier, HybridBERTRegression, HybridBERTForQA
from gluonnlp.model import get_model
from model.classification import BERTClassifier, BERTRegression
from model.qa import BertForQA

parser = argparse.ArgumentParser(description='Export hybrid BERT base model.')

Expand Down Expand Up @@ -126,35 +127,35 @@
seq_length = args.seq_length

if args.task == 'classification':
bert, _ = get_hybrid_model(
bert, _ = get_model(
name=args.model_name,
dataset_name=args.dataset_name,
pretrained=False,
use_pooler=True,
use_decoder=False,
use_classifier=False,
seq_length=args.seq_length)
net = HybridBERTClassifier(bert, num_classes=2, dropout=args.dropout)
net = BERTClassifier(bert, num_classes=2, dropout=args.dropout)
elif args.task == 'regression':
bert, _ = get_hybrid_model(
bert, _ = get_model(
name=args.model_name,
dataset_name=args.dataset_name,
pretrained=False,
use_pooler=True,
use_decoder=False,
use_classifier=False,
seq_length=args.seq_length)
net = HybridBERTRegression(bert, dropout=args.dropout)
net = BERTRegression(bert, dropout=args.dropout)
elif args.task == 'question_answering':
bert, _ = get_hybrid_model(
bert, _ = get_model(
name=args.model_name,
dataset_name=args.dataset_name,
pretrained=False,
use_pooler=False,
use_decoder=False,
use_classifier=False,
seq_length=args.seq_length)
net = HybridBERTForQA(bert)
net = BertForQA(bert)
else:
raise ValueError('unknown task: %s'%args.task)

Expand All @@ -163,7 +164,7 @@
else:
net.initialize()
warnings.warn('--model_parameters is not provided. The parameter checkpoint (.params) '
'file will be created based on default parameter intialization.')
'file will be created based on default parameter initialization.')

net.hybridize(static_alloc=True, static_shape=True)

Expand Down
22 changes: 0 additions & 22 deletions scripts/bert/export/__init__.py

This file was deleted.

Loading

0 comments on commit 7847de2

Please sign in to comment.