From 16ca431ad3b48bdf6f6ad8b9ae36d24b188670df Mon Sep 17 00:00:00 2001
From: Jeongseok Kang
Date: Wed, 5 Jun 2024 09:41:50 +0900
Subject: [PATCH 1/2] Add `distribution_strategy` and `all_reduce_alg` flags to
 TensorFlow BERT pretraining

---
 .../tensorflow/bert/run_pretraining.py | 26 +++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/language_model/tensorflow/bert/run_pretraining.py b/language_model/tensorflow/bert/run_pretraining.py
index 7de5514e4..993e09d4c 100644
--- a/language_model/tensorflow/bert/run_pretraining.py
+++ b/language_model/tensorflow/bert/run_pretraining.py
@@ -117,6 +117,28 @@
 flags.DEFINE_integer("keep_checkpoint_max", 5,
                      "The maximum number of checkpoints to keep.")
 
+flags.DEFINE_string(
+    name="distribution_strategy", short_name="ds", default="mirrored",
+    help="The Distribution Strategy to use for training. "
+    "Accepted values are 'off', 'one_device', "
+    "'mirrored', 'parameter_server', 'collective', "
+    "case insensitive. 'off' means not to use "
+    "Distribution Strategy; 'default' means to choose "
+    "from `MirroredStrategy` or `OneDeviceStrategy` "
+    "according to the number of GPUs.")
+
+flags.DEFINE_string(
+    name="all_reduce_alg", short_name="ara", default="nccl",
+    help="Defines the algorithm to use for performing all-reduce. "
+    "When specified with MirroredStrategy for single "
+    "worker, this controls "
+    "tf.contrib.distribute.AllReduceCrossTowerOps. When "
+    "specified with MultiWorkerMirroredStrategy, this "
+    "controls "
+    "tf.distribute.experimental.CollectiveCommunication; "
+    "valid options are `ring` and `nccl`.")
+
+
 def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                      num_train_steps, num_warmup_steps, use_tpu,
                      use_one_hot_embeddings, optimizer, poly_power,
@@ -542,9 +564,9 @@ def main(_):
       allow_soft_placement=True)
 
   distribution_strategy = distribution_utils.get_distribution_strategy(
-      distribution_strategy="mirrored",
+      distribution_strategy=flags.distribution_strategy,
       num_gpus=FLAGS.num_gpus,
-      all_reduce_alg="nccl",
+      all_reduce_alg=flags.all_reduce_alg,
       num_packs=0)
 
   dist_gpu_config = tf.estimator.RunConfig(

From 77934d072d750a0a3c6a8125d86131ae76553f33 Mon Sep 17 00:00:00 2001
From: Jeongseok Kang
Date: Wed, 5 Jun 2024 13:21:16 +0900
Subject: [PATCH 2/2] fix: Use 'FLAGS' instead of 'flags'

---
 language_model/tensorflow/bert/run_pretraining.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/language_model/tensorflow/bert/run_pretraining.py b/language_model/tensorflow/bert/run_pretraining.py
index 993e09d4c..a7edd376f 100644
--- a/language_model/tensorflow/bert/run_pretraining.py
+++ b/language_model/tensorflow/bert/run_pretraining.py
@@ -564,9 +564,9 @@ def main(_):
       allow_soft_placement=True)
 
   distribution_strategy = distribution_utils.get_distribution_strategy(
-      distribution_strategy=flags.distribution_strategy,
+      distribution_strategy=FLAGS.distribution_strategy,
       num_gpus=FLAGS.num_gpus,
-      all_reduce_alg=flags.all_reduce_alg,
+      all_reduce_alg=FLAGS.all_reduce_alg,
       num_packs=0)
 
   dist_gpu_config = tf.estimator.RunConfig(
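
Note (not part of the patches above): a minimal, self-contained sketch of how the two new absl flags are expected to flow into distribution_utils.get_distribution_strategy(). The flag names, defaults, and the keyword arguments of get_distribution_strategy() are taken from the diffs; the absl scaffolding and the `official.utils.misc.distribution_utils` import path are assumptions based on the TF Model Garden layout and may differ in this repository.

from absl import app
from absl import flags

# Assumed import path; run_pretraining.py may obtain distribution_utils elsewhere.
from official.utils.misc import distribution_utils

flags.DEFINE_string("distribution_strategy", "mirrored",
                    "One of 'off', 'one_device', 'mirrored', "
                    "'parameter_server', 'collective'.")
flags.DEFINE_string("all_reduce_alg", "nccl",
                    "All-reduce algorithm: 'ring' or 'nccl'.")
flags.DEFINE_integer("num_gpus", 1, "Number of GPUs to train with.")

FLAGS = flags.FLAGS


def main(_):
  # Read the parsed values from FLAGS, not from the flags module itself --
  # the mix-up that PATCH 2/2 corrects.
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      num_packs=0)
  print("Selected strategy:", type(strategy).__name__)


if __name__ == "__main__":
  app.run(main)

With the patches applied, the strategy can be selected from the command line, e.g. `--distribution_strategy=mirrored --all_reduce_alg=nccl`, instead of being hard-coded to "mirrored" and "nccl" in main().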