
Commit 112ad15

Yangzhang/codonfm 10b (#1305)
### Description

Add FSDP2 support for Transformer Engine (TE).

#### Usage

```bash
python -m src.runner pretrain \
    --exp_name "$exp_name" \
    --model_name encodon_80m \
    --data_path /data/ncbi/processed_unfiltered/ \
    --process_item mlm_memmap \
    --dataset_name CodonMemmapDataset \
    --lr $learning_rate \
    --num_gpus $num_gpus \
    --num_nodes $num_nodes \
    --train_batch_size $train_batch_size \
    --val_batch_size $val_batch_size \
    --num_workers $num_workers \
    --collate_fn thd \
    --attn_input_format thd \
    --use_transformer_engine \
    --bf16 \
    --split_name_prefix nopathogen \
    --checkpoints_dir results/${exp_name}/checkpoints/ \
    --out_dir results/${exp_name}/ \
    --enable_fsdp
```

### Type of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

- [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR
- [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebook execution tests for bionemo2
- [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single-GPU integration tests marked as @pytest.mark.slow for bionemo2
- [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running all bionemo2 tests.
- [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see [CONTRIBUTING](CONTRIBUTING.md).

> [!NOTE]
> By default, only basic unit tests are run. Add the appropriate labels to enable additional test coverage.

#### Authorizing CI Runs

We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources.

- If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123).
- If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit.

### Pre-submit Checklist

- [x] I have tested these changes locally
- [ ] I have updated the documentation accordingly
- [x] I have added/updated tests as needed
- [x] All existing tests pass successfully

---------

Signed-off-by: Yang Zhang <[email protected]>
1 parent da4686a commit 112ad15

File tree

14 files changed (+525 −47 lines)


bionemo-recipes/recipes/codonfm_ptl_te/experiment_scripts/pretraining/encodon_filtered/mlm/encodon_1b.sh

Lines changed: 5 additions & 2 deletions
@@ -11,11 +11,13 @@ val_batch_size=4
 effective_batch_size=$((train_batch_size * num_gpus * num_nodes))
 num_workers=12
 
+exp_name="encodon_1b_latest_${learning_rate}_${effective_batch_size}_nopathogen"
+
 # Note if you would like to use WandB please add --enable_wandb, --project_name and --entity.
 
 # - run
 python -m src.runner pretrain \
-    --exp_name encodon_1b_baseline_${learning_rate}_${effective_batch_size}_nopathogen \
+    --exp_name "$exp_name" \
     --model_name encodon_1b \
     --data_path /data/ncbi/processed_unfiltered/ \
     --process_item mlm_memmap \
@@ -31,4 +33,5 @@ python -m src.runner pretrain \
     --num_workers $num_workers \
     --bf16 \
     --split_name_prefix nopathogen \
-    --out_dir /workspace/codonfm/results \
+    --checkpoints_dir results/${exp_name}/checkpoints/ \
+    --out_dir results/${exp_name}/ \

bionemo-recipes/recipes/codonfm_ptl_te/experiment_scripts/pretraining/encodon_filtered/mlm/encodon_600m.sh

Lines changed: 2 additions & 1 deletion
@@ -32,4 +32,5 @@ python -m src.runner pretrain \
     --use_transformer_engine \
     --bf16 \
     --split_name_prefix nopathogen \
-    --checkpoints_dir results/checkpoints/${exp_name} \
+    --checkpoints_dir results/${exp_name}/checkpoints/ \
+    --out_dir results/${exp_name}/ \

bionemo-recipes/recipes/codonfm_ptl_te/experiment_scripts/pretraining/encodon_filtered/mlm/encodon_80m.sh

Lines changed: 2 additions & 1 deletion
@@ -31,4 +31,5 @@ python -m src.runner pretrain \
     --bf16 \
     --split_name_prefix nopathogen \
     --use_transformer_engine \
-    --checkpoints_dir results/checkpoints/${exp_name} \
+    --checkpoints_dir results/${exp_name}/checkpoints/ \
+    --out_dir results/${exp_name}/

bionemo-recipes/recipes/codonfm_ptl_te/requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -148,7 +148,6 @@ natsort==8.4.0
 nbclient==0.10.2
 nbconvert==7.16.6
 nbformat==5.10.4
-nemo_run @ git+https://github.com/NVIDIA/NeMo-Run.git@3ec63b951a3cf3733358f3ed2a55e87bf466d263
 nest-asyncio==1.6.0
 networkx==3.5
 ninja==1.11.1.4

bionemo-recipes/recipes/codonfm_ptl_te/src/config.py

Lines changed: 36 additions & 2 deletions
@@ -33,14 +33,15 @@
 from src.models.encodon_pl import EncodonPL
 from src.models.encodon_te_pl import EncodonTEPL
 from src.tokenizer import Tokenizer
+from src.utils.fsdp_config import get_fsdp_strategy
 from src.utils.grad_norm_callback import GradientNormLogger
 from src.utils.pred_writer import PredWriter
 from src.utils.scheduler import linear_scheduler_with_warmup_lr_lambda
 from src.utils.timer import StepTimingCallback
 
 
 # Datasets
-def get_dataset_config(args: Any, process_item_cfg: fdl.Partial) -> fdl.Config:
+def get_dataset_config(args: Any, process_item_cfg: fdl.Partial) -> fdl.Config:  # noqa: C901
     """Builds the dataset configuration."""
     class_name = args.dataset_name
     if class_name == "CodonMemmapDataset":
@@ -49,6 +50,8 @@ def get_dataset_config(args: Any, process_item_cfg: fdl.Partial) -> fdl.Config:
         module_path = "src.data.mutation_dataset"
     elif class_name == "CodonBertDataset":
         module_path = "src.data.codon_bert_dataset"
+    elif class_name == "SimpleCodonDataset":
+        module_path = "src.data.simple_codon_dataset"
     else:
         raise ValueError(f"Unknown dataset name: {class_name}")
 
@@ -94,6 +97,12 @@ def get_dataset_config(args: Any, process_item_cfg: fdl.Partial) -> fdl.Config:
             tokenizer=tokenizer_cfg,
             process_item=process_item_cfg,
         )
+    elif class_name == "SimpleCodonDataset":
+        # SimpleCodonDataset doesn't need data_path, tokenizer, or most other args
+        dataset_cfg = fdl.Partial(
+            dataset_class,
+            process_item=process_item_cfg,
+        )
     else:
         print(f"Warning: Using generic config for dataset '{args.dataset_name}'.")
         dataset_cfg = fdl.Partial(dataset_class, **common_args)
@@ -114,6 +123,7 @@ def get_callbacks_config(args: Any) -> Dict[str, fdl.Config]:
             mode="min",
             save_top_k=1,
             auto_insert_metric_name=False,
+            enable_version_counter=False,
         ),
         "early_stopping": fdl.Config(
             EarlyStopping,
@@ -217,6 +227,12 @@ def get_logger_config(args: Any) -> fdl.Config:
 
 # Model
 MODEL_ARCHITECTURES: Dict[str, Dict[str, Any]] = {
+    "encodon_200k": {
+        "hidden_size": 128,
+        "intermediate_size": 512,
+        "num_attention_heads": 4,
+        "num_hidden_layers": 2,
+    },
     "encodon_80m": {
         "hidden_size": 1024,
         "intermediate_size": 4096,
@@ -235,6 +251,12 @@
         "num_attention_heads": 16,
         "num_hidden_layers": 18,
     },
+    "encodon_10b": {
+        "hidden_size": 5120,
+        "intermediate_size": 20480,
+        "num_attention_heads": 40,
+        "num_hidden_layers": 34,
+    },
 }
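The new `encodon_200k` and `encodon_10b` entries only pin layer shapes. As a back-of-the-envelope check on the "10b" name (assuming a plain transformer block with Q/K/V/O attention projections and a two-matrix MLP, ignoring embeddings, biases, and norms), the estimate below lands near 10.7B parameters; the real Encodon block may differ, so treat it as an order-of-magnitude sanity check only.

```python
# Rough parameter estimate for a MODEL_ARCHITECTURES entry.
# Per layer: 4*h^2 for attention projections + 2*h*i for an up/down MLP.
def approx_params(hidden_size: int, intermediate_size: int, num_hidden_layers: int) -> int:
    per_layer = 4 * hidden_size**2 + 2 * hidden_size * intermediate_size
    return per_layer * num_hidden_layers


print(f"{approx_params(5120, 20480, 34) / 1e9:.1f}B")  # ~10.7B for the encodon_10b entry
```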
@@ -304,12 +326,24 @@ def get_model_config(args: Any) -> fdl.Config:
 # Trainer
 def get_trainer_config(args: Any) -> Dict[str, Any]:
     """Builds the trainer configuration arguments."""
+    # Configure strategy based on args
+    if args.enable_fsdp:
+        # Use proper FSDP/FSDP2 strategy with auto-wrap policy
+        # This ensures FSDP uses LESS memory than DDP
+        strategy = get_fsdp_strategy(
+            cpu_offload=getattr(args, "fsdp_cpu_offload", False), activation_checkpointing=False, use_fsdp2=True
+        )
+    elif args.mode == "finetune":
+        strategy = "ddp_find_unused_parameters_true"
+    else:
+        strategy = "ddp"
+
     trainer_kwargs = dict(  # noqa: C408
         num_nodes=args.num_nodes,
         devices=args.num_gpus,
         max_steps=args.max_steps,
         default_root_dir=args.out_dir,
-        strategy="ddp" if args.mode != "finetune" else "ddp_find_unused_parameters_true",
+        strategy=strategy,
         precision="bf16-mixed" if getattr(args, "bf16", False) else "32-true",
         limit_val_batches=args.limit_val_batches,
         log_every_n_steps=args.log_every_n_steps,
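The `get_fsdp_strategy` helper itself lives in `src/utils/fsdp_config.py`, which is not shown in this excerpt. As a rough illustration of what such a helper can look like with Lightning's built-in `FSDPStrategy`, here is a minimal sketch whose signature mirrors the call site above; the FSDP2 branch (`use_fsdp2=True`) and activation checkpointing are deliberately left out, so this is a hypothetical stand-in, not the repo's actual implementation.

```python
# Hypothetical sketch only: the real src/utils/fsdp_config.py is not shown in this diff
# and reportedly supports FSDP2; this sketch covers just the classic FSDP path,
# assuming the lightning 2.x package is installed.
import functools

from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def get_fsdp_strategy(cpu_offload: bool = False, activation_checkpointing: bool = False, use_fsdp2: bool = True):
    # Wrap any submodule above ~1M parameters so layers are sharded individually;
    # activation_checkpointing and use_fsdp2 are accepted but ignored in this sketch.
    wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)
    return FSDPStrategy(
        auto_wrap_policy=wrap_policy,
        cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None,
        sharding_strategy="FULL_SHARD",  # shard parameters, gradients, and optimizer state
    )
```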
bionemo-recipes/recipes/codonfm_ptl_te/src/data/simple_codon_dataset.py

Lines changed: 182 additions & 0 deletions (new file)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Simple synthetic codon dataset for testing and demo purposes.

This dataset generates random sequences on-the-fly without requiring any data files.
Useful for quick testing, debugging, and development without setting up real data.
"""

from typing import Callable, Optional

import numpy as np
from torch.utils.data import Dataset

from src.data.metadata import MetadataFields


class SimpleCodonDataset(Dataset):
    """Simple synthetic dataset that generates random codon sequences.

    This dataset is useful for:
    - Quick testing without setting up data files
    - Debugging model training loops
    - Development and prototyping
    - FSDP/distributed training tests

    Args:
        num_samples (int): Number of samples in the dataset. Defaults to 10000.
        seq_length (int): Length of each sequence. Defaults to 2048.
        vocab_size (int): Size of the vocabulary. Defaults to 69 (codon vocabulary size).
        split_name (str): Split name ('train', 'val', 'test', or 'all'). Defaults to 'all'.
        train_ratio (float): Ratio of training samples. Defaults to 0.8.
        val_ratio (float): Ratio of validation samples. Defaults to 0.1.
        process_item (Callable, optional): Function to process items. Not used in this dataset.
        seed (int, optional): Random seed for reproducibility.
    """

    def __init__(
        self,
        num_samples: int = 10000,
        seq_length: int = 2048,
        vocab_size: int = 69,
        split_name: str = "all",
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        process_item: Optional[Callable] = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """Initialize the SimpleCodonDataset."""
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.split_name = split_name
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = 1.0 - train_ratio - val_ratio
        self.process_item = process_item
        self.seed = seed

        # Calculate split boundaries
        train_end = int(num_samples * train_ratio)
        val_end = train_end + int(num_samples * val_ratio)

        # Set the actual samples for this split
        if split_name == "train":
            self.start_idx = 0
            self.end_idx = train_end
        elif split_name == "val":
            self.start_idx = train_end
            self.end_idx = val_end
        elif split_name == "test":
            self.start_idx = val_end
            self.end_idx = num_samples
        else:  # 'all'
            self.start_idx = 0
            self.end_idx = num_samples

        self.actual_num_samples = self.end_idx - self.start_idx

    def __len__(self):
        """Return the number of samples in this split."""
        return self.actual_num_samples

    def __getitem__(self, idx):
        """Generate a random codon sequence sample.

        Args:
            idx: Index of the sample to retrieve.

        Returns:
            Dictionary containing:
            - INPUT_IDS: Random token IDs (numpy array)
            - LABELS: Random labels for MLM (numpy array)
            - ATTENTION_MASK: All ones (no padding) (numpy array)
            - INPUT_MASK: All ones (no masking) (numpy array)
        """
        # Use deterministic random generation based on seed and index
        if self.seed is not None:
            rng = np.random.default_rng(self.seed + self.start_idx + idx)
        else:
            rng = np.random.default_rng()

        return {
            MetadataFields.INPUT_IDS: rng.integers(0, self.vocab_size, size=self.seq_length, dtype=np.int64),
            MetadataFields.LABELS: rng.integers(0, self.vocab_size, size=self.seq_length, dtype=np.int64),
            MetadataFields.ATTENTION_MASK: np.ones(self.seq_length, dtype=bool),
            MetadataFields.INPUT_MASK: np.ones(self.seq_length, dtype=bool),
        }

    def get_train(self, process_item: Optional[Callable] = None) -> "SimpleCodonDataset":
        """Return the training split of the dataset.

        Args:
            process_item: Optional processing function (not used in this dataset).

        Returns:
            SimpleCodonDataset instance for the training split.
        """
        return SimpleCodonDataset(
            num_samples=self.num_samples,
            seq_length=self.seq_length,
            vocab_size=self.vocab_size,
            split_name="train",
            train_ratio=self.train_ratio,
            val_ratio=self.val_ratio,
            process_item=process_item or self.process_item,
            seed=self.seed,
        )

    def get_validation(self, process_item: Optional[Callable] = None) -> "SimpleCodonDataset":
        """Return the validation split of the dataset.

        Args:
            process_item: Optional processing function (not used in this dataset).

        Returns:
            SimpleCodonDataset instance for the validation split.
        """
        return SimpleCodonDataset(
            num_samples=self.num_samples,
            seq_length=self.seq_length,
            vocab_size=self.vocab_size,
            split_name="val",
            train_ratio=self.train_ratio,
            val_ratio=self.val_ratio,
            process_item=process_item or self.process_item,
            seed=self.seed,
        )

    def get_test(self, process_item: Optional[Callable] = None) -> "SimpleCodonDataset":
        """Return the test split of the dataset.

        Args:
            process_item: Optional processing function (not used in this dataset).

        Returns:
            SimpleCodonDataset instance for the test split.
        """
        return SimpleCodonDataset(
            num_samples=self.num_samples,
            seq_length=self.seq_length,
            vocab_size=self.vocab_size,
            split_name="test",
            train_ratio=self.train_ratio,
            val_ratio=self.val_ratio,
            process_item=process_item or self.process_item,
            seed=self.seed,
        )
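Because `SimpleCodonDataset` is registered in `get_dataset_config` above, it can be selected with `--dataset_name SimpleCodonDataset` for smoke tests without any data files. For a quick standalone check (run from the recipe root so the `src` package is importable), something like the following exercises the splits and the collated batch shapes:

```python
from torch.utils.data import DataLoader

from src.data.simple_codon_dataset import SimpleCodonDataset

dataset = SimpleCodonDataset(num_samples=100, seq_length=256, seed=42)
train_split = dataset.get_train()  # 80 samples with the default train_ratio=0.8
sample = train_split[0]            # dict of numpy arrays keyed by MetadataFields

loader = DataLoader(train_split, batch_size=8)
batch = next(iter(loader))         # default collate stacks each field to shape (8, 256)
print(len(train_split), {k: v.shape for k, v in batch.items()})
```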
