diff --git a/02-multi-gpu/train_llm.py b/02-multi-gpu/train_llm.py
index 865325e..8d73013 100644
--- a/02-multi-gpu/train_llm.py
+++ b/02-multi-gpu/train_llm.py
@@ -14,6 +14,7 @@
 from torch.nn.parallel import DistributedDataParallel
 from torch import distributed as dist
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.optim import ZeroRedundancyOptimizer
 
 import wandb
 import tqdm
@@ -79,7 +80,9 @@ def main():
     )
     LOGGER.info(f"{len(dataloader)} batches per epoch")
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    optimizer = ZeroRedundancyOptimizer(
+        model.parameters(), optimizer_class=torch.optim.AdamW, lr=args.lr
+    )
     lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=1000, eta_min=args.lr * 1e-2
     )
@@ -191,6 +194,8 @@ def _load_to_device(p):
                 t.reset()
 
             if state["global_step"] % args.ckpt_freq == 0:
+                # Have to pull the shards to rank 0 before we save the state dict
+                ZeroRedundancyOptimizer.consolidate_state_dict(optimizer, to=0)
                 if rank == 0:
                     LOGGER.info("Saving checkpoint.")
                     torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")
diff --git a/03-multi-node/train_llm.py b/03-multi-node/train_llm.py
index aeb6fc5..e28aaae 100644
--- a/03-multi-node/train_llm.py
+++ b/03-multi-node/train_llm.py
@@ -14,6 +14,7 @@
 from torch.nn.parallel import DistributedDataParallel
 from torch import distributed as dist
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.optim import ZeroRedundancyOptimizer
 
 import wandb
 import tqdm
@@ -80,7 +81,9 @@ def main():
     )
     LOGGER.info(f"{len(dataloader)} batches per epoch")
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    optimizer = ZeroRedundancyOptimizer(
+        model.parameters(), optimizer_class=torch.optim.AdamW, lr=args.lr
+    )
     lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=1000, eta_min=args.lr * 1e-2
     )
@@ -193,6 +196,7 @@ def _load_to_device(p):
                 t.reset()
 
             if state["global_step"] % args.ckpt_freq == 0:
+                ZeroRedundancyOptimizer.consolidate_state_dict(optimizer, to=0)
                 if rank == 0:
                     LOGGER.info("Saving checkpoint.")
                     torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")