Skip to content

Commit

Permalink
sts-b support added
Browse files Browse the repository at this point in the history
  • Loading branch information
kamalkraj committed Dec 9, 2019
1 parent 5a620f3 commit 8d0cc21
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
20 changes: 15 additions & 5 deletions classifier_data_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,9 +659,10 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
label_id=0,
is_real_example=False)

label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
if FLAGS.classification_task_name.lower() != "sts":
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i

tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
Expand Down Expand Up @@ -729,7 +730,11 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length

label_id = label_map[example.label]
if FLAGS.classification_task_name.lower() != "sts":
label_id = label_map[example.label]
else:
label_id = example.label

if ex_index < 5:
logging.info("*** Example ***")
logging.info("guid: %s" % (example.guid))
Expand Down Expand Up @@ -766,12 +771,17 @@ def file_based_convert_examples_to_features(examples, label_list,
def create_int_feature(values):
    """Wrap an iterable of integers in a `tf.train.Feature` (Int64List)."""
    # Materialize the iterable first; the proto field is built from a list.
    int64_values = tf.train.Int64List(value=list(values))
    return tf.train.Feature(int64_list=int64_values)

def create_float_feature(values):
    """Wrap an iterable of floats in a `tf.train.Feature` (FloatList)."""
    # Materialize the iterable first; the proto field is built from a list.
    float_values = tf.train.FloatList(value=list(values))
    return tf.train.Feature(float_list=float_values)

features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["label_ids"] = create_int_feature([feature.label_id])
features["label_ids"] = create_float_feature([feature.label_id])\
if FLAGS.classification_task_name.lower() == "sts" else create_int_feature([feature.label_id])
features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)])

Expand Down
2 changes: 1 addition & 1 deletion input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def create_classifier_dataset(file_path,
'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64),
'label_ids': tf.io.FixedLenFeature([], tf.int64) if FLAGS.task_name.lower() != "sts" else tf.io.FixedLenFeature([], tf.float32),
'is_real_example': tf.io.FixedLenFeature([], tf.int64),
}
input_fn = file_based_input_fn_builder(file_path, name_to_features)
Expand Down
42 changes: 34 additions & 8 deletions run_classifer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ def classification_loss_fn(labels, logits):

return classification_loss_fn

def get_loss_fn_v2(loss_factor=1.0):
    """Build the loss function used for the STS task.

    Args:
      loss_factor: Scalar multiplied into the mean loss before it is returned.

    Returns:
      A callable `(labels, logits) -> loss` that computes the mean of the
      squared difference between the squeezed logits and the labels, scaled
      by `loss_factor`.
    """

    def sts_loss_fn(labels, logits):
        """Return mean((squeeze(logits) - labels)^2) * loss_factor."""
        # Drop the trailing axis of the logits before subtracting the labels.
        predictions = tf.squeeze(logits, [-1])
        squared_error = tf.square(predictions - labels)
        return tf.reduce_mean(squared_error) * loss_factor

    return sts_loss_fn

def get_model(albert_config, max_seq_length, num_labels, init_checkpoint, learning_rate,
num_train_steps, num_warmup_steps,loss_multiplier):
"""Returns keras fuctional model"""
Expand Down Expand Up @@ -217,10 +230,13 @@ def get_model(albert_config, max_seq_length, num_labels, init_checkpoint, learni
beta_2=0.999,
epsilon=FLAGS.adam_epsilon,
exclude_from_weight_decay=['layer_norm', 'bias'])

loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer=optimizer,loss=loss_fct,metrics=['accuracy'])

if FLAGS.task_name.lower() == 'sts':
loss_fct = tf.keras.losses.MeanSquaredError()
model.compile(optimizer=optimizer,loss=loss_fct,metrics=['mse'])
else:
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer,loss=loss_fct,metrics=['accuracy'])

return model

Expand Down Expand Up @@ -316,10 +332,16 @@ def main(_):
custom_callbacks = [summary_callback, checkpoint_callback]

def metric_fn():
return tf.keras.metrics.SparseCategoricalAccuracy(dtype=tf.float32)
if FLAGS.task_name.lower() == "sts":
return tf.keras.metrics.MeanSquaredError(dtype=tf.float32)
else:
return tf.keras.metrics.SparseCategoricalAccuracy(dtype=tf.float32)

if FLAGS.custom_training_loop:
loss_fn = get_loss_fn(num_labels,loss_factor=loss_multiplier)
if FLAGS.task_name.lower() == "sts":
loss_fn = get_loss_fn_v2(loss_factor=loss_multiplier)
else:
loss_fn = get_loss_fn(num_labels,loss_factor=loss_multiplier)
model = run_customized_training_loop(strategy = strategy,
model = model,
loss_fn = loss_fn,
Expand Down Expand Up @@ -396,8 +418,12 @@ def metric_fn():

with strategy.scope():
logits = model.predict(prediction_dataset)
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
probabilities = tf.nn.softmax(logits, axis=-1)
if FLAGS.task_name.lower() == "sts":
predictions = logits
probabilities = logits
else:
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
probabilities = tf.nn.softmax(logits, axis=-1)

output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
output_submit_file = os.path.join(FLAGS.output_dir, "submit_results.tsv")
Expand Down

0 comments on commit 8d0cc21

Please sign in to comment.