#40 Calling DistributedSampler.set_epoch
corey-lambda committed Oct 21, 2024
1 parent c25b7c7 commit 574444a
Showing 6 changed files with 16 additions and 0 deletions.
10 changes: 10 additions & 0 deletions 02-multi-gpu/README.md
@@ -219,6 +219,16 @@ As discussed before, this will let each rank grab a different subset of the data
```python
    ...
)
```

You also need to call [DistributedSampler.set_epoch](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler) at the start of each epoch so that shuffling changes between epochs:

```diff
+dataloader.sampler.set_epoch(state["epoch"])
batches = iter(dataloader)
```

The PyTorch documentation explains why:

> In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used.
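
To make the effect concrete, here's a minimal sketch of the pattern (not this repo's exact code: the toy dataset, batch size, and epoch count are placeholders, and it assumes a distributed launch, e.g. via `torchrun`):

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group("nccl")  # assumes launch via torchrun

dataset = TensorDataset(torch.arange(64))  # placeholder dataset
sampler = DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    # Without this call, every epoch replays the epoch-0 shuffle order.
    sampler.set_epoch(epoch)
    for batch in dataloader:
        pass  # training step goes here
```

Internally, `set_epoch` folds the epoch number into the sampler's shuffle seed (`seed + epoch`), so each epoch gets a distinct ordering while the run as a whole stays reproducible.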

### Only creating experiment directory on rank 0

Note the `dist.barrier()` calls before and after we create the directory. **These are very important!**
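
For reference, a minimal sketch of that pattern (the `exp_dir` path is a placeholder, and this assumes the process group is already initialized):

```python
import os
import torch.distributed as dist

exp_dir = "outputs/my-experiment"  # placeholder path

dist.barrier()  # sync all ranks before touching the filesystem
if dist.get_rank() == 0:
    os.makedirs(exp_dir, exist_ok=True)  # only rank 0 creates the directory
dist.barrier()  # no rank proceeds until rank 0 has finished
```

Without the second barrier, a non-zero rank could try to write into the directory before it exists.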
2 changes: 2 additions & 0 deletions 02-multi-gpu/train_llm.py
```diff
@@ -138,6 +138,8 @@ def _load_to_device(p):
 if state["epoch_step"] > 0:
     progress_bar.update(state["epoch_step"])

+# We need to do this so we shuffle differently on each epoch in a reproducible way.
+dataloader.sampler.set_epoch(state["epoch"])
 batches = iter(dataloader)

 for i_step in range(len(dataloader)):
```
1 change: 1 addition & 0 deletions 03-multi-node/train_llm.py
```diff
@@ -141,6 +141,7 @@ def _load_to_device(p):
 if state["epoch_step"] > 0:
     progress_bar.update(state["epoch_step"])

+dataloader.sampler.set_epoch(state["epoch"])
 batches = iter(dataloader)

 for i_step in range(len(dataloader)):
```
1 change: 1 addition & 0 deletions 05-sharding-deepspeed/train_llm.py
```diff
@@ -132,6 +132,7 @@ def main():
 if state["epoch_step"] > 0:
     progress_bar.update(state["epoch_step"])

+dataloader.sampler.set_epoch(state["epoch"])
 batches = iter(dataloader)

 for i_step in range(len(dataloader)):
```
1 change: 1 addition & 0 deletions 05-sharding-fsdp/train_llm.py
```diff
@@ -205,6 +205,7 @@ def safe_param_init_fn(module: torch.nn.Module):
 if state["epoch_step"] > 0:
     progress_bar.update(state["epoch_step"])

+dataloader.sampler.set_epoch(state["epoch"])
 batches = iter(dataloader)

 for i_step in range(len(dataloader)):
```
1 change: 1 addition & 0 deletions 06-training-llama-405b/train_llm.py
```diff
@@ -220,6 +220,7 @@ def main():
 if state["epoch_step"] > 0:
     progress_bar.update(state["epoch_step"])

+dataloader.sampler.set_epoch(state["epoch"])
 batches = iter(dataloader)

 for i_step in range(len(dataloader)):
```
