"""
Util functions for setting up distributed training.
Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py
"""

import os
import torch

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def is_global_master(args):
    return args.rank == 0


def is_local_master(args):
    return args.local_rank == 0


def is_master(args, local=False):
    return is_local_master(args) if local else is_global_master(args)


def is_using_horovod():
    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
    # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    return all(var in os.environ for var in ompi_vars) or all(
        var in os.environ for var in pmi_vars
    )


def is_using_distributed():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"]) > 1
    if "SLURM_NTASKS" in os.environ:
        return int(os.environ["SLURM_NTASKS"]) > 1
    return False


def world_info_from_env():
    # Read (local_rank, global_rank, world_size) from whichever launcher set
    # them: torchrun / torch.distributed.launch, MPI, SLURM, or Open MPI.
    local_rank = 0
    for v in (
        "LOCAL_RANK",
        "MPI_LOCALRANKID",
        "SLURM_LOCALID",
        "OMPI_COMM_WORLD_LOCAL_RANK",
    ):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
        if v in os.environ:
            world_size = int(os.environ[v])
            break

    return local_rank, global_rank, world_size
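

# Illustrative note (an assumption, not part of the original file): under
# `torchrun --nproc_per_node=4` on a single node, torchrun exports LOCAL_RANK,
# RANK, and WORLD_SIZE, so world_info_from_env() in the process with rank 1
# returns (1, 1, 4) and init_distributed_device() below takes the DDP branch.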
def init_distributed_device(args):
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        args.local_rank = int(hvd.local_rank())
        args.rank = hvd.rank()
        args.world_size = hvd.size()
        args.distributed = True
        os.environ["LOCAL_RANK"] = str(args.local_rank)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)
    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM vars -> torch.distributed vars, in case they are needed
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun or torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
    else:
        # needed to run on a single GPU
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=1,
            rank=0,
        )

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = f"cuda:{args.local_rank}"
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    device = torch.device(device)
    return device
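

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical, not part of the original file). It builds
# a bare argparse.Namespace carrying the attributes init_distributed_device()
# reads (horovod, dist_backend, dist_url, no_set_device_rank) and initializes a
# single-process group. The MASTER_ADDR/MASTER_PORT defaults and the gloo
# fallback backend are assumptions chosen so the sketch also runs on a CPU-only
# machine; real multi-GPU jobs would typically be launched via torchrun with
# dist_backend="nccl".
if __name__ == "__main__":
    import argparse

    # init_method="env://" needs a rendezvous address even for world_size == 1.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    args = argparse.Namespace(
        horovod=False,
        dist_backend="nccl" if torch.cuda.is_available() else "gloo",
        dist_url="env://",
        no_set_device_rank=False,
    )
    device = init_distributed_device(args)
    print(
        f"rank={args.rank} local_rank={args.local_rank} "
        f"world_size={args.world_size} device={device}"
    )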