
Commit b3e99eb

Add distributed checkpointing tests and fix pin_memory compatibility

- Add comprehensive distributed checkpointing tests (8 tests total)
  - Single and multi-GPU checkpoint save/resume for DDP and FSDP2
  - Final model save tests for inference export
  - Scheduler resume tests
- Disable pin_memory in dataloader due to PyTorch 2.9/torchdata 0.11 incompatibility
- Add checkpoint verification to multi-GPU tests
- Improve test documentation and docstrings
- Add wandb project config field to avoid hydra struct errors

Signed-off-by: Savitha Srinivasan <[email protected]>

1 parent 6aa1620 commit b3e99eb
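
The checkpoint save/resume tests described above rely on the resumable dataloader from torchdata. As a rough sketch of the round-trip such tests exercise with torchdata.stateful_dataloader.StatefulDataLoader (the toy dataset and batch size are illustrative, not the recipe's actual test code):

# Sketch of the save/resume round-trip the new checkpoint tests exercise;
# not the actual test code from this commit.
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(list(range(10)), batch_size=2)
it = iter(loader)
print(next(it))              # tensor([0, 1])
state = loader.state_dict()  # capture mid-epoch dataloader position

resumed = StatefulDataLoader(list(range(10)), batch_size=2)
resumed.load_state_dict(state)  # restore before creating the iterator
print(next(iter(resumed)))      # tensor([2, 3]) -- resumes where it left off

The key property being verified is that resumption continues mid-epoch rather than restarting the dataset from the beginning.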

File tree: 5 files changed, +757 −24 lines changed

bionemo-recipes/recipes/llama3/dataset.py
Lines changed: 10 additions & 11 deletions

@@ -14,17 +14,15 @@
 # limitations under the License.
 
 import logging
-from pathlib import Path
 
 import datasets
 import datasets.distributed
+from distributed_config import DistributedConfig
 from torch.utils.data import DistributedSampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import AutoTokenizer
 from transformers.data.data_collator import DataCollatorForLanguageModeling
 
-from distributed_config import DistributedConfig
-
 
 logger = logging.getLogger(__name__)
 
@@ -39,7 +37,7 @@ def create_tokenized_dataset(
     use_lazy_tokenization: bool = True,
 ):
     """Create a tokenized dataset with windowing.
-
+
     Args:
         distributed_config: The distributed configuration.
         tokenizer_path: Path to the nucleotide tokenizer directory.
@@ -48,7 +46,7 @@ def create_tokenized_dataset(
         stride: The stride for windowing (overlap = stride tokens).
         buffer_size: The buffer size for shuffle.
         use_lazy_tokenization: Whether to use datasets.set_transform for tokenization.
-
+
     Returns:
         Tuple of (tokenized_dataset, tokenizer).
     """
@@ -61,8 +59,10 @@
     if "train" in dataset:
         dataset = dataset["train"]
     else:
-        raise ValueError(f"Dataset has splits {list(dataset.keys())} but no 'train' split found. "
-                         "Please specify split='train' in load_dataset_kwargs or ensure your dataset has a 'train' split.")
+        raise ValueError(
+            f"Dataset has splits {list(dataset.keys())} but no 'train' split found. "
+            "Please specify split='train' in load_dataset_kwargs or ensure your dataset has a 'train' split."
+        )
 
     # Normalize column names - rename 'nt_sequence' to 'sequence' if present
     # Only do this for non-streaming datasets (streaming datasets don't have column_names attribute)
@@ -120,7 +120,7 @@ def create_bshd_dataloader(
     use_lazy_tokenization: bool = True,
 ):
     """Create a BSHD dataloader for genomic sequences using CLM (causal language modeling).
-
+
     Args:
         distributed_config: The distributed configuration.
         tokenizer_path: Path to the nucleotide tokenizer directory.
@@ -132,7 +132,7 @@
         seed: The seed to use for the distributed sampler and data collator.
         buffer_size: The buffer size for shuffle.
         use_lazy_tokenization: Whether to use datasets.set_transform for tokenization.
-
+
     Returns:
         A tuple of (dataloader, dataset_or_sampler).
     """
@@ -168,9 +168,8 @@
         batch_size=micro_batch_size,
         collate_fn=data_collator,
         num_workers=num_workers,
-        pin_memory=True,
+        pin_memory=False,  # Disabled due to PyTorch 2.9 compatibility issue with torchdata 0.11.0
         persistent_workers=num_workers > 0,
     )
 
     return train_dataloader, tokenized_dataset if sampler is None else sampler
-
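
The change above hard-codes pin_memory=False for every environment. If the incompatibility is indeed specific to the PyTorch 2.9 / torchdata 0.11.0 pairing, a version gate could keep pinning enabled elsewhere; a minimal sketch (the helper name and version checks are assumptions, not part of this commit):

# Hypothetical helper, not part of this commit: only disable pinned memory
# on the PyTorch 2.9 / torchdata 0.11 pairing reported as incompatible.
from importlib.metadata import version

import torch
from packaging.version import Version


def pin_memory_supported() -> bool:
    torch_29 = Version(torch.__version__).release[:2] == (2, 9)
    torchdata_011 = Version(version("torchdata")).release[:2] == (0, 11)
    return not (torch_29 and torchdata_011)

The dataloader construction would then pass pin_memory=pin_memory_supported() instead of a literal False.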

bionemo-recipes/recipes/llama3/hydra_config/L0_sanity.yaml
Lines changed: 1 addition & 1 deletion

@@ -27,6 +27,7 @@ dataset:
 wandb_init_args:
   name: "llama3_8B_genomic_sanity"
   mode: "offline"
+  project: null # Set to null by default, override with +wandb_init_args.project=your-project
 
 # Learning rate scheduler config
 lr_scheduler_kwargs:
@@ -41,4 +42,3 @@ checkpoint:
 
 logger:
   frequency: 1
-
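
One note on the inline comment: Hydra's + prefix appends a key that is absent from the config, and now that project is declared, a plain wandb_init_args.project=your-project override works (the + form would report that the item already exists). Leaving the default in place is also safe, since wandb.init accepts project=None and falls back to its default project resolution. A minimal sketch of forwarding the composed args, assuming the recipe passes them to wandb.init directly:

# Sketch, not the recipe's train script: forwarding the composed
# wandb_init_args mapping into wandb.init. project=None is accepted and
# falls back to wandb's default project.
import wandb
from omegaconf import OmegaConf

wandb_init_args = OmegaConf.create(
    {"name": "llama3_8B_genomic_sanity", "mode": "offline", "project": None}
)
run = wandb.init(**OmegaConf.to_container(wandb_init_args))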

bionemo-recipes/recipes/llama3/hydra_config/defaults.yaml
Lines changed: 1 addition & 3 deletions

@@ -26,6 +26,7 @@ dataset:
 # WandB config
 wandb_init_args:
   name: ???
+  project: null # Optional: set to your wandb project name
 
 # mFSDP config
 fully_shard_kwargs:
@@ -73,6 +74,3 @@ checkpoint:
 
 logger:
   frequency: 100
-
-
-
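
For reference, the "hydra struct errors" this default avoids: Hydra composes configs with OmegaConf struct mode enabled, so reading or writing a key the config does not declare raises an error instead of returning or creating it. A standalone OmegaConf sketch of the failure mode (not the recipe's config):

# Standalone illustration of an OmegaConf struct error; declaring
# `project: null` in the config schema is what makes the access legal.
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError

cfg = OmegaConf.create({"wandb_init_args": {"name": "run", "mode": "offline"}})
OmegaConf.set_struct(cfg, True)  # Hydra enables struct mode on composed configs

try:
    print(cfg.wandb_init_args.project)  # key not declared -> ConfigAttributeError
except ConfigAttributeError as err:
    print(f"struct error: {err}")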
