From 281644543097e9089e0e9b4b264c4a0a91877dca Mon Sep 17 00:00:00 2001 From: Cyril Meurillon Date: Wed, 11 Dec 2024 18:15:37 -0800 Subject: [PATCH] ADLR/megatron-lm!2443 - Fix assert warning in !2282 Co-authored-by: Cyril Meurillon --- megatron/core/rerun_state_machine.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/megatron/core/rerun_state_machine.py b/megatron/core/rerun_state_machine.py index 22b13b0c9e..62e1d95475 100644 --- a/megatron/core/rerun_state_machine.py +++ b/megatron/core/rerun_state_machine.py @@ -252,8 +252,7 @@ def train_step(data_iterator, ...): for d in data_iterators: assert ( isinstance(d, RerunDataIterator), - "data iterator is not wrapped with RerunDataIterator", - ) + ), "data iterator is not wrapped with RerunDataIterator" # Are we about to start the initial run? if self.state == RerunState.NOT_RUNNING_YET: @@ -263,8 +262,7 @@ def train_step(data_iterator, ...): if self.data_iterator_checkpoints is not None: assert ( len(self.data_iterator_checkpoints) == len(data_iterators), - "data_iterator has different length than checkpointed data iterator", - ) + ), "data iterator has different length than checkpointed data iterator" for i, d in enumerate(data_iterators): d.set_checkpoint_state(self.data_iterator_checkpoints[i]) self.data_iterator_checkpoints = None @@ -667,8 +665,7 @@ def save_my_model_checkpoint(data_iterator, ...): for d in data_iterators: assert ( isinstance(d, RerunDataIterator), - "data iterator is not wrapped with RerunDataIterator", - ) + ), "data iterator is not wrapped with RerunDataIterator" state: dict[str, Any] = { 'mode': self.mode,