Skip to content

Commit

Permalink
ADLR/megatron-lm!2453 - Fix wrapping of external dataloaders
Browse files Browse the repository at this point in the history
Co-authored-by: Cyril Meurillon <[email protected]>
  • Loading branch information
2 people authored and ericharper committed Dec 12, 2024
1 parent d4e72c0 commit ebfc79b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions megatron/core/rerun_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,8 @@ class MyDataIterator:
replay_data_iterator = RerunDataIterator(data_iterator)
"""

-    def __init__(self, iterable: Any, make_iterable: bool = True) -> None:
-        self.iterable: Iterable[Any] = iter(iterable) if make_iterable else iterable
+    def __init__(self, iterable: Iterable[Any]) -> None:
+        self.iterable: Iterable[Any] = iterable
         self.saved_microbatches: list[Any] = []
         self.replaying: bool = False
         self.replay_pos: int = 0
Expand Down
9 changes: 6 additions & 3 deletions megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,12 +1878,15 @@ def build_train_valid_test_data_iterators(
     def _get_iterator(dataloader_type, dataloader):
         """Return dataset iterator."""
         if dataloader_type == "single":
-            return RerunDataIterator(dataloader)
+            return RerunDataIterator(iter(dataloader))
         elif dataloader_type == "cyclic":
-            return RerunDataIterator(cyclic_iter(dataloader))
+            return RerunDataIterator(iter(cyclic_iter(dataloader)))
         elif dataloader_type == "external":
             # External dataloader is passed through. User is expected to define how to iterate.
-            return RerunDataIterator(dataloader, make_iterable=False)
+            if isinstance(dataloader, list):
+                return [RerunDataIterator(d) for d in dataloader]
+            else:
+                return RerunDataIterator(dataloader)
         else:
             raise RuntimeError("unexpected dataloader type")

Expand Down

0 comments on commit ebfc79b

Please sign in to comment.