#44 Adding ZeroRedundancyOptimizer to ch 2,3
corey-lambda committed Oct 21, 2024
1 parent 574444a commit 2c7401e
Showing 2 changed files with 11 additions and 2 deletions.
02-multi-gpu/train_llm.py  (6 additions, 1 deletion)

@@ -14,6 +14,7 @@
 from torch.nn.parallel import DistributedDataParallel
 from torch import distributed as dist
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.optim import ZeroRedundancyOptimizer
 
 import wandb
 import tqdm
@@ -79,7 +80,9 @@ def main():
     )
     LOGGER.info(f"{len(dataloader)} batches per epoch")
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    optimizer = ZeroRedundancyOptimizer(
+        model.parameters(), optimizer_class=torch.optim.AdamW, lr=args.lr
+    )
     lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=1000, eta_min=args.lr * 1e-2
     )
@@ -191,6 +194,8 @@ def _load_to_device(p):
                 t.reset()
 
             if state["global_step"] % args.ckpt_freq == 0:
+                # Have to pull the shards to rank 0 before we save the state dict
+                ZeroRedundancyOptimizer.consolidate_state_dict(optimizer, to=0)
                 if rank == 0:
                     LOGGER.info("Saving checkpoint.")
                     torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")
03-multi-node/train_llm.py  (5 additions, 1 deletion)

@@ -14,6 +14,7 @@
 from torch.nn.parallel import DistributedDataParallel
 from torch import distributed as dist
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.optim import ZeroRedundancyOptimizer
 
 import wandb
 import tqdm
@@ -80,7 +81,9 @@ def main():
     )
     LOGGER.info(f"{len(dataloader)} batches per epoch")
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    optimizer = ZeroRedundancyOptimizer(
+        model.parameters(), optimizer_class=torch.optim.AdamW, lr=args.lr
+    )
     lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=1000, eta_min=args.lr * 1e-2
     )
@@ -193,6 +196,7 @@ def _load_to_device(p):
                 t.reset()
 
             if state["global_step"] % args.ckpt_freq == 0:
+                ZeroRedundancyOptimizer.consolidate_state_dict(optimizer, to=0)
                 if rank == 0:
                     LOGGER.info("Saving checkpoint.")
                     torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")
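
Both diffs make the same two changes: the base AdamW is wrapped in ZeroRedundancyOptimizer (PyTorch's ZeRO-style optimizer state sharding), and consolidate_state_dict is called on every rank before rank 0 writes the checkpoint, because each rank's state_dict otherwise only covers its own shard. A minimal, self-contained sketch of that pattern, assuming a toy linear model and a torchrun --nproc-per-node=N launch rather than the tutorial's actual training script:

    # Sketch only, not part of the commit: illustrates ZeroRedundancyOptimizer
    # wrapping AdamW under DDP, plus consolidation before checkpointing.
    import os

    import torch
    from torch import distributed as dist
    from torch.distributed.optim import ZeroRedundancyOptimizer
    from torch.nn.parallel import DistributedDataParallel


    def main():
        # torchrun provides RANK/WORLD_SIZE/LOCAL_RANK and the rendezvous env vars.
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        local_rank = int(os.environ["LOCAL_RANK"])
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(device)

        model = DistributedDataParallel(
            torch.nn.Linear(1024, 1024).to(device), device_ids=[local_rank]
        )

        # Each rank keeps only its shard of the AdamW state (exp_avg, exp_avg_sq),
        # so optimizer memory shrinks roughly by a factor of world_size.
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(), optimizer_class=torch.optim.AdamW, lr=3e-4
        )

        for step in range(1, 101):
            x = torch.randn(8, 1024, device=device)
            loss = model(x).square().mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % 50 == 0:
                # Collective call: every rank sends its shard to rank 0 so that
                # rank 0's state_dict() contains the full optimizer state.
                optimizer.consolidate_state_dict(to=0)
                if rank == 0:
                    torch.save(optimizer.state_dict(), "optimizer.pt")

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

Note that consolidate_state_dict is a collective, so it has to run on every rank, not just rank 0; the commit places it outside the rank-0 check for the same reason.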
