Commit 3562f19

Dataset class and tests for LLAMA3; support for streaming, parquet
Signed-off-by: savitha-eng <[email protected]>
1 parent eb9dfc4 commit 3562f19

File tree: 4 files changed, +625 −0 lines changed
Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path

import datasets
import datasets.distributed
from torch.utils.data import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling

from distributed_config import DistributedConfig


logger = logging.getLogger(__name__)


def create_tokenized_dataset(
    distributed_config: DistributedConfig,
    tokenizer_path: str,
    load_dataset_kwargs: dict,
    max_seq_length: int = 8192,
    stride: int = 200,
    buffer_size: int = 500_000,
    use_lazy_tokenization: bool = True,
):
    """Create a tokenized dataset with windowing.

    Args:
        distributed_config: The distributed configuration.
        tokenizer_path: Path to the nucleotide tokenizer directory.
        load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
        max_seq_length: The maximum length of sequences (window size).
        stride: The stride for windowing (overlap = stride tokens).
        buffer_size: The buffer size for shuffle.
        use_lazy_tokenization: Whether to use `datasets.Dataset.with_transform` for lazy tokenization.

    Returns:
        Tuple of (tokenized_dataset, tokenizer).
    """
    logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
    dataset = datasets.load_dataset(**load_dataset_kwargs)
    logger.info(f"Loaded dataset: {dataset}")

    # Handle DatasetDict (extract "train" split if present).
    if isinstance(dataset, (datasets.DatasetDict, datasets.IterableDatasetDict)):
        if "train" in dataset:
            dataset = dataset["train"]
        else:
            raise ValueError(
                f"Dataset has splits {list(dataset.keys())} but no 'train' split found. "
                "Please specify split='train' in load_dataset_kwargs or ensure your dataset has a 'train' split."
            )

    # Normalize column names: rename 'nt_sequence' to 'sequence' if present.
    # Only do this for non-streaming datasets (streaming datasets may not expose column_names).
    if hasattr(dataset, "column_names") and dataset.column_names is not None:
        if "nt_sequence" in dataset.column_names and "sequence" not in dataset.column_names:
            logger.info("Renaming column 'nt_sequence' to 'sequence' for consistency")
            dataset = dataset.rename_column("nt_sequence", "sequence")

    if isinstance(dataset, datasets.IterableDataset):
        dataset = datasets.distributed.split_dataset_by_node(
            dataset,
            rank=distributed_config.rank,
            world_size=distributed_config.world_size,
        )
        dataset = dataset.shuffle(seed=42, buffer_size=buffer_size)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def tokenize_with_windowing(examples):
        """Tokenize nucleotide sequences with windowing (one-to-many mapping)."""
        # Tokenize with windowing using return_overflowing_tokens.
        result = tokenizer(
            examples["sequence"],
            max_length=max_seq_length,
            stride=stride,
            truncation=True,
            return_overflowing_tokens=True,
            add_special_tokens=True,
        )
        return result

    if isinstance(dataset, datasets.Dataset) and use_lazy_tokenization:
        # dataset.map on a non-streaming dataset would eagerly tokenize and cache the result;
        # with_transform applies the tokenization lazily, per batch, at access time.
        tokenized_dataset = dataset.with_transform(tokenize_with_windowing)
    else:
        tokenized_dataset = dataset.map(
            tokenize_with_windowing,
            batched=True,
            remove_columns=dataset.column_names,
        )

    return tokenized_dataset, tokenizer


def create_bshd_dataloader(
    distributed_config: DistributedConfig,
    tokenizer_path: str,
    load_dataset_kwargs: dict,
    micro_batch_size: int,
    num_workers: int = 0,
    max_seq_length: int = 8192,
    stride: int = 200,
    seed: int = 42,
    buffer_size: int = 500_000,
    use_lazy_tokenization: bool = True,
):
    """Create a BSHD dataloader for genomic sequences using CLM (causal language modeling).

    Args:
        distributed_config: The distributed configuration.
        tokenizer_path: Path to the nucleotide tokenizer directory.
        load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
        micro_batch_size: The batch size per device.
        num_workers: The number of workers to use for the dataloader.
        max_seq_length: The maximum length of sequences (window size).
        stride: The stride for windowing (overlap = stride tokens).
        seed: The seed to use for the distributed sampler and data collator.
        buffer_size: The buffer size for shuffle.
        use_lazy_tokenization: Whether to use `datasets.Dataset.with_transform` for lazy tokenization.

    Returns:
        A tuple of (dataloader, dataset_or_sampler).
    """
    tokenized_dataset, tokenizer = create_tokenized_dataset(
        distributed_config=distributed_config,
        tokenizer_path=tokenizer_path,
        load_dataset_kwargs=load_dataset_kwargs,
        max_seq_length=max_seq_length,
        stride=stride,
        buffer_size=buffer_size,
        use_lazy_tokenization=use_lazy_tokenization,
    )

    if isinstance(tokenized_dataset, datasets.IterableDataset):
        sampler = None
    else:
        sampler = DistributedSampler(
            tokenized_dataset,
            rank=distributed_config.rank,
            num_replicas=distributed_config.world_size,
            seed=seed,
        )

    # Use DataCollatorForLanguageModeling with mlm=False for CLM.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Causal language modeling (no masking)
    )

    train_dataloader = StatefulDataLoader(
        tokenized_dataset,
        sampler=sampler,
        batch_size=micro_batch_size,
        collate_fn=data_collator,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )

    return train_dataloader, (tokenized_dataset if sampler is None else sampler)
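For context, here is a minimal usage sketch (not part of this commit) showing how create_bshd_dataloader might be called against a parquet file loaded in streaming mode. The module name `dataset`, the tokenizer directory, and the parquet path are placeholders, not names from the commit.

# Illustrative sketch only; module name and file paths below are assumptions.
from dataset import create_bshd_dataloader  # hypothetical module name for the file above
from distributed_config import DistributedConfig

dist_config = DistributedConfig()
dataloader, dataset_or_sampler = create_bshd_dataloader(
    distributed_config=dist_config,
    tokenizer_path="path/to/nucleotide_tokenizer",  # placeholder tokenizer directory
    load_dataset_kwargs={
        "path": "parquet",
        "data_files": "path/to/genomic_sequences.parquet",  # placeholder parquet file
        "split": "train",
        "streaming": True,  # exercises the IterableDataset branch (split_dataset_by_node + shuffle)
    },
    micro_batch_size=2,
    max_seq_length=1024,
    stride=64,
)

# Each batch is a dict from the CLM collator with input_ids, attention_mask, and labels.
first_batch = next(iter(dataloader))
print(first_batch["input_ids"].shape)

With streaming=True the dataset takes the map branch, so tokenization happens on the fly per shard, and dataset_or_sampler is the tokenized IterableDataset because no DistributedSampler is created.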
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from dataclasses import dataclass, field


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class DistributedConfig:
    """Class to track distributed ranks and handle basic distributed training setup.

    If torch distributed environment variables are not set, we set them to default values for single-process training.

    Attributes:
        rank: The rank of the process.
        local_rank: The local rank of the process.
        world_size: The total number of processes.
    """

    rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0")))
    local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0")))
    world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1")))
    _master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost"))
    _master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355"))

    def is_main_process(self) -> bool:
        """Return True if this is the global rank 0 process, to be used for wandb logging, etc."""
        return self.rank == 0
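A small illustration (not part of the commit) of the fallback behaviour described in the docstring: constructing DistributedConfig in a process where torchrun has not exported the distributed variables populates them with single-process defaults via os.environ.setdefault.

import os

from distributed_config import DistributedConfig

# Assumes RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT are not already set.
config = DistributedConfig()
print(config.rank, config.local_rank, config.world_size)  # 0 0 1
print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])  # localhost 12355
print(config.is_main_process())  # True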
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path
from unittest import mock

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
import torch
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh


sys.path.append(Path(__file__).parent.parent.as_posix())
sys.path.append(Path(__file__).parent.as_posix())

from distributed_config import DistributedConfig


@pytest.fixture
def recipe_path() -> Path:
    """Return the root directory of the recipe."""
    return Path(__file__).parent.parent


@pytest.fixture(scope="session")
def mock_genomic_parquet(tmp_path_factory) -> Path:
    """Create a mock genomic sequences parquet file for testing.

    This fixture creates a small parquet file with synthetic genomic sequences
    that can be used for training tests without relying on external data files.

    Returns:
        Path to the generated parquet file.
    """
    tmp_dir = tmp_path_factory.mktemp("data")
    parquet_path = tmp_dir / "test_genomic_sequences.parquet"

    # Create mock genomic sequences with simple repeating patterns.
    # These are easy for the model to overfit to, which makes them well suited to sanity tests.
    sequences = [
        "ATCG" * 300,  # 1200 bp - simple ATCG repeat
        "AAAA" * 250 + "TTTT" * 250,  # 2000 bp - an A block followed by a T block
        "GCGC" * 200,  # 800 bp - GC repeat
        "ACGT" * 400,  # 1600 bp - all 4 nucleotides
        "TGCA" * 350,  # 1400 bp - reverse pattern
    ]

    # Create a parquet table with a 'sequence' column.
    table = pa.table({
        "sequence": sequences,
    })

    pq.write_table(table, parquet_path)
    return parquet_path


@pytest.fixture(scope="session", autouse=True)
def device_mesh():
    """Create a re-usable device mesh for testing.

    This is an "auto-use", session-scoped fixture so that a single device mesh is created and used in all tests.

    Megatron-FSDP runs into issues when the torch device mesh is re-created in the same process, starting with the
    25.09 NGC PyTorch container release. To work around this, we create a re-usable device mesh that is used in all
    single-process tests.
    """
    # Initialize the distributed configuration, including creating the distributed process group.
    dist_config = DistributedConfig()
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.distributed.init_process_group(backend="nccl", device_id=device)
    torch.cuda.set_device(dist_config.local_rank)
    device_mesh = init_device_mesh("cuda", mesh_shape=(1, 1), mesh_dim_names=("dp", "tp"))

    # Mock these torch.distributed functions so that we re-use the same device mesh, and don't re-create or destroy
    # the global process group.
    with (
        mock.patch("torch.distributed.device_mesh.init_device_mesh", return_value=device_mesh),
        mock.patch("torch.distributed.init_process_group", return_value=None),
        mock.patch("torch.distributed.destroy_process_group", return_value=None),
    ):
        yield

    # At the end of all tests, destroy the process group and clear the device mesh resources.
    torch.distributed.destroy_process_group()
    _mesh_resources.mesh_stack.clear()
    _mesh_resources.child_to_root_mapping.clear()
    _mesh_resources.root_to_flatten_mapping.clear()
    _mesh_resources.flatten_name_to_root_dims.clear()
    _mesh_resources.mesh_dim_group_options.clear()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
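A hedged sketch of how these fixtures could be combined in a test; the dataset module name, the tokenizer path, and the test body are assumptions rather than code from this commit.

# Hypothetical test sketch; `dataset` as the module name and the tokenizer path are assumptions.
from dataset import create_bshd_dataloader
from distributed_config import DistributedConfig


def test_dataloader_yields_clm_batches(mock_genomic_parquet, recipe_path):
    dist_config = DistributedConfig()
    dataloader, _ = create_bshd_dataloader(
        distributed_config=dist_config,
        tokenizer_path=str(recipe_path / "tokenizer"),  # placeholder tokenizer directory
        load_dataset_kwargs={
            "path": "parquet",
            "data_files": str(mock_genomic_parquet),
            "split": "train",
        },
        micro_batch_size=2,
        max_seq_length=512,
        stride=32,
        use_lazy_tokenization=False,  # eagerly tokenize the small fixture via dataset.map
    )
    batch = next(iter(dataloader))
    assert "input_ids" in batch and "labels" in batch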
