# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path

import datasets
import datasets.distributed
from torch.utils.data import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling

from distributed_config import DistributedConfig


logger = logging.getLogger(__name__)


def create_tokenized_dataset(
    distributed_config: DistributedConfig,
    tokenizer_path: str,
    load_dataset_kwargs: dict,
    max_seq_length: int = 8192,
    stride: int = 200,
    buffer_size: int = 500_000,
    use_lazy_tokenization: bool = True,
):
    """Create a tokenized dataset with windowing.

    Args:
        distributed_config: The distributed configuration.
        tokenizer_path: Path to the nucleotide tokenizer directory.
        load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
        max_seq_length: The maximum length of sequences (window size).
        stride: The stride for windowing (overlap = stride tokens).
        buffer_size: The buffer size for shuffle.
        use_lazy_tokenization: Whether to tokenize non-streaming datasets lazily via `Dataset.with_transform`.

    Returns:
        Tuple of (tokenized_dataset, tokenizer).
    """
    logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
    dataset = datasets.load_dataset(**load_dataset_kwargs)
    logger.info(f"Loaded dataset: {dataset}")

    # Handle DatasetDict (extract "train" split if present)
    if isinstance(dataset, (datasets.DatasetDict, datasets.IterableDatasetDict)):
        if "train" in dataset:
            dataset = dataset["train"]
        else:
            raise ValueError(
                f"Dataset has splits {list(dataset.keys())} but no 'train' split found. "
                "Please specify split='train' in load_dataset_kwargs or ensure your dataset has a 'train' split."
            )

    # Normalize column names - rename 'nt_sequence' to 'sequence' if present.
    # Only do this when column names are known up front (streaming datasets may
    # report column_names as None).
    if hasattr(dataset, "column_names") and dataset.column_names is not None:
        if "nt_sequence" in dataset.column_names and "sequence" not in dataset.column_names:
            logger.info("Renaming column 'nt_sequence' to 'sequence' for consistency")
            dataset = dataset.rename_column("nt_sequence", "sequence")

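    # For streaming (IterableDataset) inputs, shard the stream across ranks here;
    # map-style datasets are sharded later by DistributedSampler instead. The shuffle
    # below is an approximate, buffer-based shuffle over `buffer_size` examples.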
    if isinstance(dataset, datasets.IterableDataset):
        dataset = datasets.distributed.split_dataset_by_node(
            dataset,
            rank=distributed_config.rank,
            world_size=distributed_config.world_size,
        )
        dataset = dataset.shuffle(seed=42, buffer_size=buffer_size)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def tokenize_with_windowing(examples):
        """Tokenize nucleotide sequences with windowing (one-to-many mapping)."""
        # Tokenize with windowing using return_overflowing_tokens
        result = tokenizer(
            examples["sequence"],
            max_length=max_seq_length,
            stride=stride,
            truncation=True,
            return_overflowing_tokens=True,
            add_special_tokens=True,
        )
        return result
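
    # Illustrative example (numbers are hypothetical, not from this dataset): with
    # max_seq_length=8192 and stride=200, return_overflowing_tokens splits a single
    # 16,000-token sequence into overlapping windows of up to 8192 tokens, each new
    # window repeating the last `stride` tokens of the previous one, so one input row
    # can map to several tokenized rows.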

    if isinstance(dataset, datasets.Dataset) and use_lazy_tokenization:
        # dataset.map on a non-streaming dataset would eagerly tokenize everything and
        # cache the result; with_transform tokenizes lazily at access time instead.
        tokenized_dataset = dataset.with_transform(tokenize_with_windowing)
    else:
        tokenized_dataset = dataset.map(
            tokenize_with_windowing,
            batched=True,
            remove_columns=dataset.column_names,
        )

    return tokenized_dataset, tokenizer


def create_bshd_dataloader(
    distributed_config: DistributedConfig,
    tokenizer_path: str,
    load_dataset_kwargs: dict,
    micro_batch_size: int,
    num_workers: int = 0,
    max_seq_length: int = 8192,
    stride: int = 200,
    seed: int = 42,
    buffer_size: int = 500_000,
    use_lazy_tokenization: bool = True,
):
    """Create a BSHD dataloader for genomic sequences using CLM (causal language modeling).

    Args:
        distributed_config: The distributed configuration.
        tokenizer_path: Path to the nucleotide tokenizer directory.
        load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
        micro_batch_size: The batch size per device.
        num_workers: The number of workers to use for the dataloader.
        max_seq_length: The maximum length of sequences (window size).
        stride: The stride for windowing (overlap = stride tokens).
        seed: The seed to use for the distributed sampler.
        buffer_size: The buffer size for shuffle.
        use_lazy_tokenization: Whether to tokenize non-streaming datasets lazily via `Dataset.with_transform`.

    Returns:
        A tuple of (dataloader, dataset_or_sampler).
    """
    tokenized_dataset, tokenizer = create_tokenized_dataset(
        distributed_config=distributed_config,
        tokenizer_path=tokenizer_path,
        load_dataset_kwargs=load_dataset_kwargs,
        max_seq_length=max_seq_length,
        stride=stride,
        buffer_size=buffer_size,
        use_lazy_tokenization=use_lazy_tokenization,
    )

    if isinstance(tokenized_dataset, datasets.IterableDataset):
        sampler = None
    else:
        sampler = DistributedSampler(
            tokenized_dataset,
            rank=distributed_config.rank,
            num_replicas=distributed_config.world_size,
            seed=seed,
        )

    # Use DataCollatorForLanguageModeling with mlm=False for CLM.
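    # With mlm=False, the collator pads each batch and sets labels to a copy of
    # input_ids with padding positions replaced by -100, which is the target format
    # expected for a causal language modeling loss.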
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Causal language modeling (no masking)
    )

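    # StatefulDataLoader (from torchdata) behaves like a regular DataLoader but also
    # exposes state_dict()/load_state_dict(), so the loader position can be saved in
    # a checkpoint and training can resume mid-epoch.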
    train_dataloader = StatefulDataLoader(
        tokenized_dataset,
        sampler=sampler,
        batch_size=micro_batch_size,
        collate_fn=data_collator,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )

    return train_dataloader, (tokenized_dataset if sampler is None else sampler)
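

# Minimal usage sketch, not part of the module's API: everything below is an
# illustrative example. The dataset name, tokenizer path, and the assumption that
# DistributedConfig accepts rank/world_size keyword arguments are hypothetical;
# adjust them to the actual training setup.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Assumed single-process configuration; the real DistributedConfig constructor
    # may differ (e.g., it may read torch.distributed state or environment variables).
    dist_config = DistributedConfig(rank=0, world_size=1)

    dataloader, dataset_or_sampler = create_bshd_dataloader(
        distributed_config=dist_config,
        tokenizer_path="path/to/nucleotide_tokenizer",  # hypothetical path
        load_dataset_kwargs={
            "path": "parquet",  # hypothetical: local parquet shards with a 'sequence' column
            "data_files": "data/*.parquet",
            "streaming": True,
        },
        micro_batch_size=2,
        num_workers=0,
        max_seq_length=8192,
        stride=200,
    )

    # Inspect the first batch: the collator yields input_ids, attention_mask, and
    # labels, each shaped [micro_batch_size, seq_len].
    batch = next(iter(dataloader))
    print({k: tuple(v.shape) for k, v in batch.items()})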