Skip to content

Commit

Permalink
Merge branch 'fix-external-dataloader' into 'main'
Browse files Browse the repository at this point in the history
Fix wrapping of external dataloaders

See merge request ADLR/megatron-lm!2453
  • Loading branch information
ericharper committed Dec 12, 2024
2 parents fd69c2f + ebfc79b commit 99f23d2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions megatron/core/rerun_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,8 +834,8 @@ class MyDataIterator:
replay_data_iterator = RerunDataIterator(data_iterator)
"""

- def __init__(self, iterable: Any, make_iterable: bool = True) -> None:
- self.iterable: Iterable[Any] = iter(iterable) if make_iterable else iterable
+ def __init__(self, iterable: Iterable[Any]) -> None:
+ self.iterable: Iterable[Any] = iterable
self.saved_microbatches: list[Any] = []
self.replaying: bool = False
self.replay_pos: int = 0
Expand Down
9 changes: 6 additions & 3 deletions megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,12 +1878,15 @@ def build_train_valid_test_data_iterators(
def _get_iterator(dataloader_type, dataloader):
"""Return dataset iterator."""
if dataloader_type == "single":
- return RerunDataIterator(dataloader)
+ return RerunDataIterator(iter(dataloader))
elif dataloader_type == "cyclic":
- return RerunDataIterator(cyclic_iter(dataloader))
+ return RerunDataIterator(iter(cyclic_iter(dataloader)))
elif dataloader_type == "external":
# External dataloader is passed through. User is expected to define how to iterate.
- return RerunDataIterator(dataloader, make_iterable=False)
+ if isinstance(dataloader, list):
+ return [RerunDataIterator(d) for d in dataloader]
+ else:
+ return RerunDataIterator(dataloader)
else:
raise RuntimeError("unexpected dataloader type")

Expand Down

0 comments on commit 99f23d2

Please sign in to comment.