From 8d0cc211361e81a648bf846d8ec84225273db0e4 Mon Sep 17 00:00:00 2001
From: kamalkraj
Date: Mon, 9 Dec 2019 04:48:16 -0800
Subject: [PATCH] sts-b support added

---
 classifier_data_lib.py | 20 +++++++++++++++-----
 input_pipeline.py      |  2 +-
 run_classifer.py       | 42 ++++++++++++++++++++++++++++++++++--------
 3 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/classifier_data_lib.py b/classifier_data_lib.py
index f917288..bbc117c 100644
--- a/classifier_data_lib.py
+++ b/classifier_data_lib.py
@@ -659,9 +659,10 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
         label_id=0,
         is_real_example=False)
 
-  label_map = {}
-  for (i, label) in enumerate(label_list):
-    label_map[label] = i
+  if FLAGS.classification_task_name.lower() != "sts":
+    label_map = {}
+    for (i, label) in enumerate(label_list):
+      label_map[label] = i
 
   tokens_a = tokenizer.tokenize(example.text_a)
   tokens_b = None
@@ -729,7 +730,11 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
   assert len(input_mask) == max_seq_length
   assert len(segment_ids) == max_seq_length
 
-  label_id = label_map[example.label]
+  if FLAGS.classification_task_name.lower() != "sts":
+    label_id = label_map[example.label]
+  else:
+    label_id = example.label
+
   if ex_index < 5:
     logging.info("*** Example ***")
     logging.info("guid: %s" % (example.guid))
@@ -766,12 +771,17 @@ def file_based_convert_examples_to_features(examples, label_list,
   def create_int_feature(values):
     f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
     return f
+
+  def create_float_feature(values):
+    f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+    return f
 
   features = collections.OrderedDict()
   features["input_ids"] = create_int_feature(feature.input_ids)
   features["input_mask"] = create_int_feature(feature.input_mask)
   features["segment_ids"] = create_int_feature(feature.segment_ids)
-  features["label_ids"] = create_int_feature([feature.label_id])
+  features["label_ids"] = create_float_feature([feature.label_id])\
+      if FLAGS.classification_task_name.lower() == "sts" else create_int_feature([feature.label_id])
   features["is_real_example"] = create_int_feature(
       [int(feature.is_real_example)])
 
diff --git a/input_pipeline.py b/input_pipeline.py
index 3273040..5ef8310 100644
--- a/input_pipeline.py
+++ b/input_pipeline.py
@@ -148,7 +148,7 @@ def create_classifier_dataset(file_path,
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
       'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
       'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
-      'label_ids': tf.io.FixedLenFeature([], tf.int64),
+      'label_ids': tf.io.FixedLenFeature([], tf.int64) if FLAGS.task_name.lower() != "sts" else tf.io.FixedLenFeature([], tf.float32),
       'is_real_example': tf.io.FixedLenFeature([], tf.int64),
   }
   input_fn = file_based_input_fn_builder(file_path, name_to_features)
diff --git a/run_classifer.py b/run_classifer.py
index d12f806..b7471cc 100644
--- a/run_classifer.py
+++ b/run_classifer.py
@@ -160,6 +160,19 @@ def classification_loss_fn(labels, logits):
   return classification_loss_fn
 
 
+def get_loss_fn_v2(loss_factor=1.0):
+  """Gets the loss function for STS."""
+
+  def sts_loss_fn(labels, logits):
+    """STS loss"""
+    logits = tf.squeeze(logits, [-1])
+    per_example_loss = tf.square(logits - labels)
+    loss = tf.reduce_mean(per_example_loss)
+    loss *= loss_factor
+    return loss
+
+  return sts_loss_fn
+
 def get_model(albert_config, max_seq_length, num_labels, init_checkpoint, learning_rate, num_train_steps,
               num_warmup_steps,loss_multiplier):
   """Returns keras fuctional model"""
@@ -217,10 +230,13 @@ def get_model(albert_config, max_seq_length, num_labels, init_checkpoint, learni
       beta_2=0.999,
       epsilon=FLAGS.adam_epsilon,
       exclude_from_weight_decay=['layer_norm', 'bias'])
-
-  loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
-
-  model.compile(optimizer=optimizer,loss=loss_fct,metrics=['accuracy'])
+
+  if FLAGS.task_name.lower() == 'sts':
+    loss_fct = tf.keras.losses.MeanSquaredError()
+    model.compile(optimizer=optimizer,loss=loss_fct,metrics=['mse'])
+  else:
+    loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+    model.compile(optimizer=optimizer,loss=loss_fct,metrics=['accuracy'])
 
   return model
 
@@ -316,10 +332,16 @@ def main(_):
   custom_callbacks = [summary_callback, checkpoint_callback]
 
   def metric_fn():
-    return tf.keras.metrics.SparseCategoricalAccuracy(dtype=tf.float32)
+    if FLAGS.task_name.lower() == "sts":
+      return tf.keras.metrics.MeanSquaredError(dtype=tf.float32)
+    else:
+      return tf.keras.metrics.SparseCategoricalAccuracy(dtype=tf.float32)
 
   if FLAGS.custom_training_loop:
-    loss_fn = get_loss_fn(num_labels,loss_factor=loss_multiplier)
+    if FLAGS.task_name.lower() == "sts":
+      loss_fn = get_loss_fn_v2(loss_factor=loss_multiplier)
+    else:
+      loss_fn = get_loss_fn(num_labels,loss_factor=loss_multiplier)
     model = run_customized_training_loop(strategy = strategy,
                                          model = model,
                                          loss_fn = loss_fn,
@@ -396,8 +418,12 @@ def metric_fn():
 
   with strategy.scope():
     logits = model.predict(prediction_dataset)
-    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
-    probabilities = tf.nn.softmax(logits, axis=-1)
+    if FLAGS.task_name.lower() == "sts":
+      predictions = logits
+      probabilities = logits
+    else:
+      predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
+      probabilities = tf.nn.softmax(logits, axis=-1)
 
   output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
   output_submit_file = os.path.join(FLAGS.output_dir, "submit_results.tsv")
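
Reviewer note: the two load-bearing changes above are (a) serializing STS-B labels as float32 instead of int64 and (b) swapping cross-entropy for a squared-error loss on squeezed logits. The following is a minimal, self-contained sketch for checking both behaviors outside the training scripts; the "label_ids" feature name and dtypes mirror the patch, while the toy values and variable names are purely illustrative.

import tensorflow as tf

# (a) Float-label round-trip: write with a FloatList, as create_float_feature
#     does, and parse with the float32 schema used for "sts" in input_pipeline.py.
serialized = tf.train.Example(features=tf.train.Features(feature={
    "label_ids": tf.train.Feature(
        float_list=tf.train.FloatList(value=[3.8])),
})).SerializeToString()

parsed = tf.io.parse_single_example(
    serialized, {"label_ids": tf.io.FixedLenFeature([], tf.float32)})
print(parsed["label_ids"].numpy())  # ~3.8 (STS-B scores lie in [0.0, 5.0])

# (b) Regression loss: the model emits (batch, 1) logits, which sts_loss_fn
#     squeezes to (batch,) before the mean-squared-error reduction.
labels = tf.constant([2.5, 4.0])
logits = tf.constant([[2.0], [4.5]])
squeezed = tf.squeeze(logits, [-1])                  # shape (batch,)
loss = tf.reduce_mean(tf.square(squeezed - labels))  # (0.5^2 + 0.5^2) / 2
print(loss.numpy())  # 0.25

Without the squeeze in (b), broadcasting (batch, 1) against (batch,) would yield a (batch, batch) matrix and a silently wrong loss, which is presumably why get_loss_fn_v2 squeezes first.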