Skip to content

Commit 8ff2e4b

Browse files
authored
add back num_workers and worldsize dataloader resumption checks (#1250)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Dataloader checkpoints now capture worker and world-size settings to ensure compatible restores. * On restore, mismatches in workers or world size trigger a warning and safely restart the dataloader from scratch; compatible states restore with a success message. * **Tests** * Added tests validating warning logs and restart behavior when worker or world-size settings differ from the saved checkpoint, and successful restore when they match. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Peter St. John <[email protected]>
1 parent 4388d0f commit 8ff2e4b

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

bionemo-recipes/recipes/esm2_native_te/checkpoint.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,8 @@ def save_dataloader(
515515
dataloader_path = ckpt_path / f"dataloader_rank_{dist_config.rank}.pt"
516516

517517
dataloader_state = dataloader.state_dict()
518+
dataloader_state["num_workers"] = dataloader.num_workers
519+
dataloader_state["num_ranks"] = dist_config.world_size
518520
torch.save(dataloader_state, dataloader_path)
519521
if dist_config.is_main_process():
520522
logger.info(f"Saved dataloader state to {dataloader_path}")
@@ -545,6 +547,18 @@ def load_dataloader(
545547
return dataloader
546548

547549
dataloader_state = torch.load(dataloader_path)
550+
551+
if (
552+
dataloader.num_workers != dataloader_state["num_workers"]
553+
or dist_config.world_size != dataloader_state["num_ranks"]
554+
):
555+
logger.warning(
556+
f"Dataloader num_workers mismatch: {dataloader.num_workers} != {dataloader_state['num_workers']} or "
557+
f"num_ranks mismatch: {dist_config.world_size} != {dataloader_state['num_ranks']}, "
558+
"starting dataloader from scratch."
559+
)
560+
return dataloader
561+
548562
dataloader.load_state_dict(dataloader_state)
549563
if dist_config.is_main_process():
550564
logger.info(f"Loaded dataloader state from {dataloader_path}")

bionemo-recipes/recipes/esm2_native_te/tests/test_dataset.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import logging
1617
import os
1718
import shutil
1819
from dataclasses import dataclass
@@ -711,3 +712,124 @@ def test_lazy_tokenization_returns_batch():
711712

712713
batch = next(iter(dataloader))
713714
assert batch is not None
715+
716+
717+
def test_stateful_dataloader_load_fails_if_num_workers_mismatch(tmp_path, caplog):
    """A state saved with ``num_workers=1`` must warn and restart when restored with ``num_workers=2``.

    Saves a dataloader checkpoint, rebuilds the dataloader with a different
    worker count, and checks that ``load_dataloader`` logs the mismatch
    warning instead of restoring the incompatible state.
    """
    ckpt_dir = tmp_path / "dl_test_num_workers_mismatch"
    shutil.rmtree(ckpt_dir, ignore_errors=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    dist_config = MockDistributedConfig(rank=0, local_rank=0, world_size=1)
    dataset_kwargs = {
        "path": "parquet",
        "split": "train",
        "data_files": "train.parquet",
        "streaming": False,
    }

    def _build_loader(workers):
        # Only num_workers varies between the saved and restored loaders.
        loader, _ = create_dataloader(
            distributed_config=dist_config,
            tokenizer_name="facebook/esm2_t6_8M_UR50D",
            load_dataset_kwargs=dataset_kwargs,
            micro_batch_size=4,
            num_workers=workers,
            mlm_probability=0,
        )
        return loader

    saved_loader = _build_loader(1)
    save_dataloader(
        dataloader=saved_loader,
        ckpt_path=ckpt_dir,
        dist_config=dist_config,
    )
    del saved_loader

    restored_loader = _build_loader(2)
    with caplog.at_level(logging.WARNING):
        load_dataloader(
            dataloader=restored_loader,
            ckpt_path=ckpt_dir,
            dist_config=dist_config,
        )

    assert (
        "Dataloader num_workers mismatch: 2 != 1 or num_ranks mismatch: 1 != 1, starting dataloader from scratch."
        in caplog.text
    )
772+
773+
774+
def test_stateful_dataloader_load_fails_if_num_ranks_mismatch(tmp_path, caplog):
    """A state saved at ``world_size=1`` must warn and restart when restored at ``world_size=2``.

    Saves a dataloader checkpoint under a single-rank config, rebuilds the
    dataloader under a two-rank config, and checks that ``load_dataloader``
    logs the mismatch warning instead of restoring the incompatible state.

    Fixes vs. the sibling test's copy-paste: the checkpoint directory name now
    says ``num_ranks`` (it previously reused ``dl_test_num_workers_mismatch``),
    and the restore-side config is named for its world_size, not a "rank 2".
    """
    dataloader_path = tmp_path / "dl_test_num_ranks_mismatch"
    shutil.rmtree(dataloader_path, ignore_errors=True)
    os.makedirs(dataloader_path, exist_ok=True)
    tokenizer_name = "facebook/esm2_t6_8M_UR50D"
    load_dataset_kwargs = {
        "path": "parquet",
        "split": "train",
        "data_files": "train.parquet",
        "streaming": False,
    }

    # Save under world_size=1.
    world1_dist_config = MockDistributedConfig(
        rank=0,
        local_rank=0,
        world_size=1,
    )

    reference_dataloader, _ = create_dataloader(
        distributed_config=world1_dist_config,
        tokenizer_name=tokenizer_name,
        load_dataset_kwargs=load_dataset_kwargs,
        micro_batch_size=4,
        num_workers=1,
        mlm_probability=0,
    )

    save_dataloader(
        dataloader=reference_dataloader,
        ckpt_path=dataloader_path,
        dist_config=world1_dist_config,
    )

    del reference_dataloader
    del world1_dist_config

    # Restore under world_size=2 (still rank 0) — only the world size differs.
    world2_dist_config = MockDistributedConfig(
        rank=0,
        local_rank=0,
        world_size=2,
    )

    reference_dataloader, _ = create_dataloader(
        distributed_config=world2_dist_config,
        tokenizer_name=tokenizer_name,
        load_dataset_kwargs=load_dataset_kwargs,
        micro_batch_size=4,
        num_workers=1,
        mlm_probability=0,
    )

    with caplog.at_level(logging.WARNING):
        load_dataloader(
            dataloader=reference_dataloader,
            ckpt_path=dataloader_path,
            dist_config=world2_dist_config,
        )

    assert (
        "Dataloader num_workers mismatch: 1 != 1 or num_ranks mismatch: 2 != 1, starting dataloader from scratch."
        in caplog.text
    )

0 commit comments

Comments
 (0)