From ebfc79b632393b7729e7bc0dff5809b0c453621f Mon Sep 17 00:00:00 2001
From: Cyril Meurillon
Date: Wed, 11 Dec 2024 20:38:41 -0800
Subject: [PATCH] ADLR/megatron-lm!2453 - Fix wrapping of external dataloaders

Co-authored-by: Cyril Meurillon
---
 megatron/core/rerun_state_machine.py | 4 ++--
 megatron/training/training.py        | 9 ++++++---
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/megatron/core/rerun_state_machine.py b/megatron/core/rerun_state_machine.py
index 22b13b0c9e..3485f90690 100644
--- a/megatron/core/rerun_state_machine.py
+++ b/megatron/core/rerun_state_machine.py
@@ -837,8 +837,8 @@ class MyDataIterator:
         replay_data_iterator = RerunDataIterator(data_iterator)
     """
 
-    def __init__(self, iterable: Any, make_iterable: bool = True) -> None:
-        self.iterable: Iterable[Any] = iter(iterable) if make_iterable else iterable
+    def __init__(self, iterable: Iterable[Any]) -> None:
+        self.iterable: Iterable[Any] = iterable
         self.saved_microbatches: list[Any] = []
         self.replaying: bool = False
         self.replay_pos: int = 0
diff --git a/megatron/training/training.py b/megatron/training/training.py
index 741a8bf0a6..401d404d1d 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -1878,12 +1878,15 @@ def build_train_valid_test_data_iterators(
     def _get_iterator(dataloader_type, dataloader):
         """Return dataset iterator."""
         if dataloader_type == "single":
-            return RerunDataIterator(dataloader)
+            return RerunDataIterator(iter(dataloader))
         elif dataloader_type == "cyclic":
-            return RerunDataIterator(cyclic_iter(dataloader))
+            return RerunDataIterator(iter(cyclic_iter(dataloader)))
         elif dataloader_type == "external":
             # External dataloader is passed through. User is expected to define how to iterate.
-            return RerunDataIterator(dataloader, make_iterable=False)
+            if isinstance(dataloader, list):
+                return [RerunDataIterator(d) for d in dataloader]
+            else:
+                return RerunDataIterator(dataloader)
         else:
             raise RuntimeError("unexpected dataloader type")
 