diff --git a/tools/checkpoint_saver_megatron.py b/tools/checkpoint_saver_megatron.py index 47f1b6c666..403105f2f3 100644 --- a/tools/checkpoint_saver_megatron.py +++ b/tools/checkpoint_saver_megatron.py @@ -138,7 +138,7 @@ def check_message(msg): if hasattr (md, 'checkpoint_args'): # These are arguments that we are either changing, or cause problems for validation if they are set # Note that some of these deal with T5 so will need to be changed if we support T5. - args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size', 'params_dtype', + args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size', 'world_size', 'params_dtype', 'num_layers_per_virtual_pipeline_stage', 'virtual_pipeline_model_parallel_size', 'masked_softmax_fusion', 'bias_gelu_fusion', 'bias_dropout_fusion', 'sequence_parallel', 'async_tensor_model_parallel_allreduce',