Skip to content

Commit

Permalink
feat(easytransfer): sync codebase with easytransfer-v0.1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
ScarletPan committed Jan 5, 2021
1 parent de4fc99 commit 1bf7635
Show file tree
Hide file tree
Showing 19 changed files with 126 additions and 83 deletions.
4 changes: 4 additions & 0 deletions easytransfer/app_zoo/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def build_train_resource_files(self, flags):
json.dump(self.__dict__, f)
if hasattr(self, "pretrain_model_name_or_path"):
copy_pretrain_model_files_to_dir(self.pretrain_model_name_or_path, flags.checkpointDir)
if self.label_enumerate_values and "," in self.label_enumerate_values:
label_dict = {label: idx for idx, label in enumerate(self.label_enumerate_values.split(","))}
with tf.gfile.GFile(os.path.join(flags.checkpointDir, "label_mapping.json"), mode='w') as f:
json.dump(label_dict, f)

def build_preprocess_config(self, flags):
first_sequence, second_sequence, label_name = \
Expand Down
1 change: 0 additions & 1 deletion easytransfer/app_zoo/text_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import tensorflow as tf
from easytransfer import preprocessors, model_zoo
from easytransfer.app_zoo.base import ApplicationModel
Expand Down
13 changes: 13 additions & 0 deletions easytransfer/app_zoo_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
sys.path.append("./")

import tensorflow as tf
try:
import tensorflow_io as tfio
except:
pass
from easytransfer import FLAGS
from easytransfer.app_zoo import get_application_model
from easytransfer.app_zoo.app_config import AppConfig
Expand Down Expand Up @@ -81,6 +85,15 @@


def main():
# Here is a hack for DSW access OSS
for argname in ["inputTable", "outputTable", "checkpointDir", "checkpointPath", "exportDirBase"]:
arg = getattr(_APP_FLAGS, argname)
if arg:
arg = arg.replace("\\x01", "\x01").replace("\\x02", "\x02")
setattr(_APP_FLAGS, argname, arg)
FLAGS.modelZooBasePath = FLAGS.modelZooBasePath.replace("\\x01", "\x01").replace("\\x02", "\x02")

# Main function start
config = AppConfig(mode=FLAGS.mode, flags=_APP_FLAGS)
app = get_application_model(config)
app.run()
Expand Down
5 changes: 1 addition & 4 deletions easytransfer/datasets/csv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ def __init__(self, output_glob, output_schema, input_queue=None, **kwargs):
job_name = 'DistTableWriter'
super(CSVWriter, self).__init__(job_name, 1, input_queue)

if six.PY3:
self.writer = open(output_glob, "w", encoding='utf8')
elif six.PY2:
self.writer = open(output_glob, "w")
self.writer = tf.gfile.Open(output_glob, "w")

self.output_schema = output_schema

Expand Down
8 changes: 0 additions & 8 deletions easytransfer/datasets/tfrecord_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,6 @@ def input_fn():
d = d.shard(len(self.worker_hosts.split(',')), self.task_index)
d = d.repeat()
d = d.shuffle(buffer_size=len(self.input_fps))
"""
def passthrough(path):
return tf.Print(path, [path], message='Path=')
d = d.map(passthrough, num_parallel_calls=1)
"""
cycle_length = min(4, len(self.input_fps))
d = d.apply(
tf.data.experimental.parallel_interleave(
Expand All @@ -130,8 +124,6 @@ def passthrough(path):

d = d.shuffle(buffer_size=self.shuffle_buffer_size)



else:
d = tf.data.TFRecordDataset(self.input_fps)
# Since we evaluate for a fixed number of steps we don't want to encounter
Expand Down
13 changes: 13 additions & 0 deletions easytransfer/feat_ext_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import json
import os
import tensorflow as tf
try:
import tensorflow_io as tfio
except:
pass
from easytransfer import Config, FLAGS
from easytransfer.app_zoo.app_utils import get_all_columns_name, get_selected_columns_schema
from easytransfer.app_zoo.feature_extractor import BertFeatureExtractor
Expand Down Expand Up @@ -139,6 +143,15 @@ def __init__(self):


def main():
# Here is a hack for DSW access OSS
for argname in ["inputTable", "outputTable", "modelName"]:
arg = getattr(_APP_FLAGS, argname)
if arg:
arg = arg.replace("\\x01", "\x01").replace("\\x02", "\x02")
setattr(_APP_FLAGS, argname, arg)
FLAGS.modelZooBasePath = FLAGS.modelZooBasePath.replace("\\x01", "\x01").replace("\\x02", "\x02")

# Main function start
config = BertFeatConfig()
app = BertFeatureExtractor(user_defined_config=config)
app.run()
Expand Down
24 changes: 16 additions & 8 deletions easytransfer/layers/encoder_decoder_whale.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import tensorflow as tf
import whale as wh
from tensorflow.python.layers.base import Layer
from .activations import gelu_new
Expand Down Expand Up @@ -30,7 +31,7 @@ class Encoder(Layer):
def __init__(self, config, **kwargs):
super(Encoder, self).__init__(**kwargs)
self.layer = [Block(config, name="layer_{}".format(i)) for i in range(config.num_hidden_layers)]
#self.layer = [Block(config, name="layer_{}".format(i)) for i in range(3)]
self.num_layers = config.num_hidden_layers

def _stage_call(self, layer_index, all_hidden_states, all_att_outputs, hidden_states, attention_mask, training):
layer_output, att_output = self.layer[layer_index]([hidden_states, attention_mask], training=training)
Expand All @@ -40,23 +41,30 @@ def _stage_call(self, layer_index, all_hidden_states, all_att_outputs, hidden_st
return all_hidden_states, all_att_outputs, hidden_states

def call(self, inputs, training=False):
tf.logging.info("***************Inside stage to split model**********")
hidden_states, attention_mask = inputs

all_hidden_states = ()
all_att_outputs = ()

bert_large_layers_count = 12
assert len(self.layer) == bert_large_layers_count
bert_base_layers_count = self.num_layers
assert len(self.layer) == bert_base_layers_count
# Use default scope.
for i in range(0, 2):
for i in range(0, 3):
all_hidden_states, all_att_outputs, hidden_states = self._stage_call(i, all_hidden_states, all_att_outputs, hidden_states, attention_mask, training)

# with wh.stage():
# for i in range(each_stage_layers_count, 2*each_stage_layers_count):
# all_hidden_states, all_att_outputs, hidden_states = self._stage_call(i, all_hidden_states, all_att_outputs, hidden_states, attention_mask, training)
with wh.stage():
for i in range(3, 6):
all_hidden_states, all_att_outputs, hidden_states = self._stage_call(i, all_hidden_states, all_att_outputs, hidden_states, attention_mask, training)
wh.current_scope_as_default()

with wh.stage():
for i in range(6, 9):
all_hidden_states, all_att_outputs, hidden_states = self._stage_call(i, all_hidden_states, all_att_outputs, hidden_states, attention_mask, training)
wh.current_scope_as_default()

with wh.stage():
for i in range(2, 12):
for i in range(9, 12):
all_hidden_states, all_att_outputs, hidden_states = self._stage_call(i, all_hidden_states, all_att_outputs, hidden_states, attention_mask, training)
wh.current_scope_as_default()

Expand Down
10 changes: 0 additions & 10 deletions easytransfer/model_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ def get_pretrained_model(pretrain_model_name_or_path, **kwargs):
elif model_type == 'videobert':
from .modeling_videobert import VideoBertPreTrainedModel
return VideoBertPreTrainedModel.get(pretrain_model_name_or_path, **kwargs)
elif model_type == "factorizedbert":
from .modeling_factorizedbert import FactorizedBertPreTrainedModel
return FactorizedBertPreTrainedModel.get(pretrain_model_name_or_path, **kwargs)
else:
raise NotImplementedError
else:
Expand Down Expand Up @@ -67,9 +64,6 @@ def get_pretrained_model(pretrain_model_name_or_path, **kwargs):
elif model_type == 'videobert':
from .modeling_videobert import VideoBertPreTrainedModel
return VideoBertPreTrainedModel.get(pretrain_model_name_or_path, **kwargs)
elif model_type == "factorizedbert":
from .modeling_factorizedbert import FactorizedBertPreTrainedModel
return FactorizedBertPreTrainedModel.get(pretrain_model_name_or_path, **kwargs)
else:
raise ValueError("model_type should be in bert, roberta, albert, imagebert, videobert")

Expand All @@ -94,10 +88,6 @@ def get_config_path(model_type, pretrain_model_name_or_path):
from .modeling_videobert import VideoBertPreTrainedModel
config_path = VideoBertPreTrainedModel.pretrained_config_archive_map[
pretrain_model_name_or_path]
elif model_type == "factorizedbert":
from .modeling_factorizedbert import FactorizedBertPreTrainedModel
config_path = FactorizedBertPreTrainedModel.pretrained_config_archive_map[
pretrain_model_name_or_path]
else:
raise ValueError("model_type should be in bert, roberta, albert, imagebert, videobert")

Expand Down
4 changes: 2 additions & 2 deletions easytransfer/postprocessors/classification_postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from easytransfer.engines.distribution import Process
import numpy as np
import six
import numpy as np
from easytransfer.engines.distribution import Process


class ClassificationPostprocessor(Process):
Expand Down
4 changes: 4 additions & 0 deletions easytransfer/postprocessors/labeling_postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def process(self, in_data):
continue
if token_orig_idx == prev_token_idx:
continue
if token_pred == -1 or token_pred > len(self.idx_label_map):
token_pred = len(self.idx_label_map) - 1
if self.idx_label_map[token_pred] == "[CLS]" or self.idx_label_map[token_pred] == "[SEP]":
token_pred = len(self.idx_label_map) - 1
final_pred.append(self.idx_label_map[token_pred])
prev_token_idx = token_orig_idx
raw_sequence_length = max(tok_to_orig_index) + 1
Expand Down
20 changes: 16 additions & 4 deletions easytransfer/preprocessors/comprehension_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,19 @@ def process(self, inputs):
for features in all_feature_list:
ret[key].append(getattr(features, key))

for key, val in ret.items():
ret[key] = np.array(val)

return ret
total_sample_num = len(ret["input_ids"])
if hasattr(self.config, "preprocess_batch_size"):
batch_size = self.config.preprocess_batch_size
elif hasattr(self.config, "predict_batch_size"):
batch_size = self.config.predict_batch_size
else:
batch_size = 12
for i in range(total_sample_num // batch_size + 1):
st = i * batch_size
end = (i + 1) * batch_size
if st >= total_sample_num:
continue
new_ret = dict()
for key, val in ret.items():
new_ret[key] = np.array(val[st:end])
self.put(new_ret)
5 changes: 4 additions & 1 deletion easytransfer/preprocessors/labeling_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import tensorflow as tf
import traceback
from collections import OrderedDict
from .preprocessor import Preprocessor, PreprocessorConfig
from .tokenization import convert_to_unicode
Expand Down Expand Up @@ -81,6 +82,8 @@ def convert_example_to_features(self, items):
tok_to_orig_index = [-1]
for i, token in enumerate(content_tokens):
sub_tokens = self.config.tokenizer.tokenize(token)
if not sub_tokens:
sub_tokens = ["[UNK]"]
all_tokens.extend(sub_tokens)
tok_to_orig_index.extend([i] * len(sub_tokens))
if label_tags is None:
Expand Down Expand Up @@ -108,7 +111,7 @@ def convert_example_to_features(self, items):
assert len(input_mask) == self.config.sequence_length
assert len(segment_ids) == self.config.sequence_length
assert len(label_ids) == self.config.sequence_length
assert max(tok_to_orig_index) == len(content_tokens) - 1
assert max(tok_to_orig_index) == len(content_tokens) - 1, "Abnormal line: {}".format(items)
return ' '.join([str(t) for t in input_ids]), \
' '.join([str(t) for t in input_mask]), \
' '.join([str(t) for t in segment_ids]), \
Expand Down
45 changes: 25 additions & 20 deletions easytransfer/preprocessors/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,32 @@ def __init__(self, **kwargs):

if "/" not in pretrain_model_name_or_path:
model_type = pretrain_model_name_or_path.split("-")[1]
if six.PY2:
import errno
def mkdir_p(path):
try:
os.makedirs(path)
except OSError as exc: # Python >2.5 (except OSError, exc: for Python <2.5)
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
mkdir_p(os.path.join(FLAGS.modelZooBasePath, model_type))
if tf.gfile.Exists(os.path.join(FLAGS.modelZooBasePath, model_type,
pretrain_model_name_or_path, "config.json")):
# If exists directory, not download
pass
else:
os.makedirs(os.path.join(FLAGS.modelZooBasePath, model_type), exist_ok=True)

des_path = os.path.join(os.path.join(FLAGS.modelZooBasePath, model_type),
pretrain_model_name_or_path + ".tgz")
if not os.path.exists(des_path):
tf.logging.info("********** Begin to download to {} **********".format(des_path))
os.system(
'wget -O ' + des_path + ' https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/eztransfer_modelzoo/' + model_type + '/' + pretrain_model_name_or_path + ".tgz")
os.system('tar -zxvf ' + des_path + ' --directory ' + FLAGS.modelZooBasePath + "/" + model_type)
if six.PY2:
import errno
def mkdir_p(path):
try:
os.makedirs(path)
except OSError as exc: # Python >2.5 (except OSError, exc: for Python <2.5)
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
mkdir_p(os.path.join(FLAGS.modelZooBasePath, model_type))
else:
os.makedirs(os.path.join(FLAGS.modelZooBasePath, model_type), exist_ok=True)

des_path = os.path.join(os.path.join(FLAGS.modelZooBasePath, model_type),
pretrain_model_name_or_path + ".tgz")
if not os.path.exists(des_path):
tf.logging.info("********** Begin to download to {} **********".format(des_path))
os.system(
'wget -O ' + des_path + ' https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/eztransfer_modelzoo/' + model_type + '/' + pretrain_model_name_or_path + ".tgz")
os.system('tar -zxvf ' + des_path + ' --directory ' + FLAGS.modelZooBasePath + "/" + model_type)

if "train" in self.mode:
model_dir = kwargs['model_dir']
Expand Down
30 changes: 15 additions & 15 deletions scripts/CLUE_GLUE_SuperGLUE_benchmark/run_tasks.sh
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
#!/usr/bin/env bash

wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/clue_glue_superglue_benchmark/clue_datasets.tgz
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/clue_glue_superglue_benchmark/glue_datasets.tgz
wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/clue_glue_superglue_benchmark/superglue_datasets.tgz
tar -zxf clue_datasets.tgz
tar -zxf glue_datasets.tgz
tar -zxf superglue_datasets.tgz
mkdir datasets
mv clue_datasets/* datasets
mv glue_datasets/* datasets
mv superglue_datasets/* datasets
rm -rf *_datasets
rm *.tgz
#wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/clue_glue_superglue_benchmark/clue_datasets.tgz
#wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/clue_glue_superglue_benchmark/glue_datasets.tgz
#wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/tutorial/clue_glue_superglue_benchmark/superglue_datasets.tgz
#tar -zxf clue_datasets.tgz
#tar -zxf glue_datasets.tgz
#tar -zxf superglue_datasets.tgz
#mkdir datasets
#mv clue_datasets/* datasets
#mv glue_datasets/* datasets
#mv superglue_datasets/* datasets
#rm -rf *_datasets
#rm *.tgz

export CUDA_VISIBLE_DEVICES="1"
export CUDA_VISIBLE_DEVICES="2"
#CLUE---> AFQMC, CMNLI, CSL, IFLYTEK, TNEWS
#GLUE---> CoLA, MRPC, QQP, RTE, SST-2
#SuperGLUE---> BoolQ, CB, COPA, WiC, WSC
Expand All @@ -25,11 +25,11 @@ python main_finetune.py --workerGPU=1 \
--train_input_fp=datasets/${task_name}/train.csv \
--eval_input_fp=datasets/${task_name}/dev.csv \
--predict_input_fp=datasets/${task_name}/test.csv \
--predict_checkpoint_path=CLUEWSC_model_dir/model.ckpt-778 \
--predict_checkpoint_path=${task_name}_model_dir/model.ckpt-0 \
--pretrain_model_name_or_path=hit-roberta-large-zh \
--train_batch_size=16 \
--num_epochs=10 \
--model_dir=${task_name}_model_dir_2 \
--model_dir=${task_name}_model_dir \
--learning_rate=3e-5 \


3 changes: 1 addition & 2 deletions scripts/meta_finetune/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
sys.path.append("../..")
import os
Expand Down Expand Up @@ -53,7 +52,7 @@ def build_eval_metrics(self, logits, labels):

def main(_):
app = Application()
if FLAGS.usePAI:
if "PAI" in tf.__version__:
train_reader = OdpsTableReader(input_glob=app.train_input_fp, is_training=True, input_schema=app.input_schema, batch_size=app.train_batch_size)
eval_reader = OdpsTableReader(input_glob=app.eval_input_fp, is_training=False, input_schema=app.input_schema, batch_size=app.eval_batch_size)
else:
Expand Down
2 changes: 1 addition & 1 deletion scripts/meta_finetune/meta_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def main(_):

app = Application()

if FLAGS.usePAI:
if "PAI" in tf.__version__:
train_reader = OdpsTableReader(input_glob=app.train_input_fp, is_training=True, input_schema=app.input_schema, batch_size=app.train_batch_size)
eval_reader = OdpsTableReader(input_glob=app.eval_input_fp, is_training=False, input_schema=app.input_schema, batch_size=app.eval_batch_size)
app.run_train_and_evaluate(train_reader=train_reader, eval_reader=eval_reader)
Expand Down
Loading

0 comments on commit 1bf7635

Please sign in to comment.