diff --git a/megatron/core/rerun_state_machine.py b/megatron/core/rerun_state_machine.py
index 62e1d95475..cb948a318b 100644
--- a/megatron/core/rerun_state_machine.py
+++ b/megatron/core/rerun_state_machine.py
@@ -834,8 +834,8 @@ class MyDataIterator:
         replay_data_iterator = RerunDataIterator(data_iterator)
     """
 
-    def __init__(self, iterable: Any, make_iterable: bool = True) -> None:
-        self.iterable: Iterable[Any] = iter(iterable) if make_iterable else iterable
+    def __init__(self, iterable: Iterable[Any]) -> None:
+        self.iterable: Iterable[Any] = iterable
         self.saved_microbatches: list[Any] = []
         self.replaying: bool = False
         self.replay_pos: int = 0
diff --git a/megatron/training/training.py b/megatron/training/training.py
index 741a8bf0a6..401d404d1d 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -1878,12 +1878,15 @@ def build_train_valid_test_data_iterators(
     def _get_iterator(dataloader_type, dataloader):
         """Return dataset iterator."""
         if dataloader_type == "single":
-            return RerunDataIterator(dataloader)
+            return RerunDataIterator(iter(dataloader))
         elif dataloader_type == "cyclic":
-            return RerunDataIterator(cyclic_iter(dataloader))
+            return RerunDataIterator(iter(cyclic_iter(dataloader)))
         elif dataloader_type == "external":
             # External dataloader is passed through. User is expected to define how to iterate.
-            return RerunDataIterator(dataloader, make_iterable=False)
+            if isinstance(dataloader, list):
+                return [RerunDataIterator(d) for d in dataloader]
+            else:
+                return RerunDataIterator(dataloader)
         else:
             raise RuntimeError("unexpected dataloader type")
 