Skip to content

Commit

Permalink
Update DDP tutorial for the correct order of set_device (#1285)
Browse files Browse the repository at this point in the history
  • Loading branch information
fegin authored Sep 17, 2024
1 parent 26de419 commit a308b4e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions distributed/ddp-tutorial-series/multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def ddp_setup(rank, world_size):
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
init_process_group(backend="nccl", rank=rank, world_size=world_size)

class Trainer:
def __init__(
Expand Down Expand Up @@ -99,6 +99,6 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
args = parser.parse_args()

world_size = torch.cuda.device_count()
mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
4 changes: 2 additions & 2 deletions distributed/ddp-tutorial-series/multigpu_torchrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


def ddp_setup():
init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
init_process_group(backend="nccl")

class Trainer:
def __init__(
Expand Down Expand Up @@ -107,5 +107,5 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
args = parser.parse_args()

main(args.save_every, args.total_epochs, args.batch_size)
4 changes: 2 additions & 2 deletions distributed/ddp-tutorial-series/multinode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


def ddp_setup():
init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
init_process_group(backend="nccl")

class Trainer:
def __init__(
Expand Down Expand Up @@ -108,5 +108,5 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
parser.add_argument('save_every', type=int, help='How often to save a snapshot')
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
args = parser.parse_args()

main(args.save_every, args.total_epochs, args.batch_size)

0 comments on commit a308b4e

Please sign in to comment.