diff --git a/tools/train.py b/tools/train.py index 29a88bde9..7aed737ee 100644 --- a/tools/train.py +++ b/tools/train.py @@ -33,7 +33,7 @@ def parse_config(): parser.add_argument('--sync_bn', action='store_true', default=False, help='whether to use sync bn') parser.add_argument('--fix_random_seed', action='store_true', default=False, help='') parser.add_argument('--ckpt_save_interval', type=int, default=1, help='number of training epochs') - parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') + parser.add_argument('--local_rank', type=int, default=None, help='local rank for distributed training') parser.add_argument('--max_ckpt_save_num', type=int, default=30, help='max number of saved checkpoint') parser.add_argument('--merge_all_iters_to_one_epoch', action='store_true', default=False, help='') parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER, @@ -71,6 +71,9 @@ def main(): dist_train = False total_gpus = 1 else: + if args.local_rank is None: + args.local_rank = int(os.environ.get('LOCAL_RANK', '0')) + total_gpus, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)( args.tcp_port, args.local_rank, backend='nccl' )