 from megatron.core.optimizer import OptimizerConfig
 from nemo import lightning as nl
 from nemo.collections import llm
+from nemo.lightning import io, resume
+from nemo.lightning.pytorch import callbacks as nl_callbacks
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule
 from nemo.lightning.pytorch.optim.lr_scheduler import CosineAnnealingScheduler
-from nemo.lightning.resume import AutoResume
 from nemo.utils import logging
 from pytorch_lightning.callbacks import LearningRateMonitor, RichModelSummary
 from torch.nn import functional as F
@@ -71,6 +72,12 @@ def main(
     wandb_entity: str = "clara-discovery",
     create_tensorboard_logger: bool = False,
     nemo1_init_path: Optional[Path] = None,
+    restore_from_checkpoint_path: Optional[str] = None,
+    save_best_checkpoint: bool = True,
+    save_last_checkpoint: bool = True,
+    metric_to_monitor_for_checkpoints: str = "val_loss",
+    save_top_k: int = 2,
+    save_every_n_steps: int = 100,
 ) -> None:
     """Train a Geneformer model on single cell data.

@@ -84,9 +91,8 @@ def main(
         wandb_offline (bool): if wandb should happen in offline mode
         num_steps (int): number of steps to train the model for
         limit_val_batches (int): limit the number of validation global batches to this many
-        val_check_interval (int): number of steps to periodically check the validation loss and save
-            an updated checkpoint
-        num_dataset_workers (int): num dataset workers
+        val_check_interval (int): number of steps to periodically check the validation loss and save num_dataset_workers (
+            int): num dataset workers
         biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
         lr (float): learning rate
         micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
@@ -97,8 +103,8 @@ def main(
         resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
         wandb_entity (str): the group to use for the wandb run, sometimes called a team, could also be your username
         create_tensorboard_logger (bool): create the tensorboard logger
-
-
+        restore_from_checkpoint_path (path): If set, restores the model from the directory passed in. Expects the
+            checkpoint to be created by using the ModelCheckpoint class and enable_nemo_ckpt_io=True.
     """
     # Create the result directory if it does not exist.
     result_dir.mkdir(parents=True, exist_ok=True)
@@ -115,7 +121,7 @@ def main(
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         ddp="megatron",
         find_unused_parameters=True,
-        enable_nemo_ckpt_io=False,
+        ckpt_include_optimizer=True,
     )

     wandb_options: Optional[WandbLoggerOptions] = (
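Switching the strategy to `ckpt_include_optimizer=True` asks for optimizer state to be persisted alongside the model weights, so a resumed run keeps its Adam moments rather than starting them from zero. Below is a plain-PyTorch analogue of that idea, not the Megatron distributed-checkpoint machinery used here; the file path is just an example.

```python
# General idea behind ckpt_include_optimizer=True: save optimizer state (e.g. Adam
# moments) together with the model weights so training resumes where it left off.
# Plain-PyTorch analogue only; NeMo/Megatron do this via distributed checkpointing.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so the optimizer accumulates state.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

ckpt = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
torch.save(ckpt, "/tmp/toy_ckpt.pt")  # example path

# Restore both pieces; skipping the optimizer state would reset Adam's moments.
restored = torch.load("/tmp/toy_ckpt.pt")
model.load_state_dict(restored["model"])
optimizer.load_state_dict(restored["optimizer"])
```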
@@ -136,11 +142,15 @@ def main(
         limit_val_batches=limit_val_batches,  # This controls upsampling and downsampling
         val_check_interval=val_check_interval,  # TODO(@jstjohn) Checkpoint saving is currently broken, fix and change this.
         num_nodes=num_nodes,
-        callbacks=[LossLoggingCallback(), RichModelSummary(max_depth=4), LearningRateMonitor()],
+        callbacks=[
+            # TODO(@skothenhill-nv) these need to be cleaned up when we have the automatic addition of track_io
+            io.track_io(LossLoggingCallback)(),
+            io.track_io(RichModelSummary)(max_depth=4),
+            io.track_io(LearningRateMonitor)(),
+        ],
         plugins=nl.MegatronMixedPrecision(precision=precision, amp_O2=False),
     )

-    # Preprocess the data to get the tokenizer and median dictionary
     preprocessor = GeneformerPreprocess(
         download_directory=train_data_path,
         medians_file_path=train_data_path / "medians.json",
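The callbacks are now wrapped with `io.track_io(...)` before instantiation so NeMo's IO machinery can record how each callback was constructed and re-create it when the checkpoint is restored. The sketch below is a toy analogue of that "capture the constructor arguments" pattern, not NeMo's actual `track_io` implementation, and `RichModelSummaryStub` is a hypothetical stand-in for a Lightning callback.

```python
# Toy analogue of capturing how an object was built so it can be rebuilt later;
# this is NOT NeMo's io.track_io implementation.
from typing import Any, Dict, Tuple, Type


def track_init_args(cls: Type) -> Type:
    """Return a subclass that records the args/kwargs used to construct it."""

    class Tracked(cls):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            # Stash enough information to rebuild an equivalent instance later.
            self.__init_spec__: Tuple[Tuple[Any, ...], Dict[str, Any]] = (args, kwargs)

    Tracked.__name__ = cls.__name__
    return Tracked


class RichModelSummaryStub:  # hypothetical stand-in for a Lightning callback
    def __init__(self, max_depth: int = 1) -> None:
        self.max_depth = max_depth


summary = track_init_args(RichModelSummaryStub)(max_depth=4)
args, kwargs = summary.__init_spec__
rebuilt = type(summary)(*args, **kwargs)  # re-create the callback from the recorded spec
assert rebuilt.max_depth == 4
```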
@@ -224,22 +234,35 @@ def main(
             ),
         ),
     )
+    # Configure our custom Checkpointer
+    checkpoint_callback = nl_callbacks.ModelCheckpoint(
+        save_best_model=save_best_checkpoint,
+        save_last=save_last_checkpoint,
+        monitor=metric_to_monitor_for_checkpoints,  # "val_loss",
+        save_top_k=save_top_k,
+        every_n_train_steps=save_every_n_steps,
+        enable_nemo_ckpt_io=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+        async_save=False,  # Tries to save asynchronously, previously led to race conditions.
+    )

     # Setup the logger and train the model
     nemo_logger = setup_nemo_lightning_logger(
         root_dir=result_dir,
         name=experiment_name,
         initialize_tensorboard_logger=create_tensorboard_logger,
         wandb_kwargs=wandb_options,
+        ckpt_callback=checkpoint_callback,
     )
-
     llm.train(
         model=model,
         data=data,
         trainer=trainer,
         log=nemo_logger,
-        # FIXME @skothenhill this doesn't work yet, but this is probably close to what we are supposed to do
-        resume=AutoResume(resume_if_exists=resume_if_exists, resume_ignore_no_checkpoint=True),
+        resume=resume.AutoResume(
+            path=restore_from_checkpoint_path,  # Overrides the path found by resume_if_exists when set.
+            resume_if_exists=resume_if_exists,  # Looks for the -last checkpoint to continue training.
+            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+        ),
     )

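The new `ModelCheckpoint` callback and `resume.AutoResume` work as a pair: the callback writes top-k and "-last" checkpoints under the experiment directory, and `AutoResume` either restores from an explicit `restore_from_checkpoint_path` or, when `resume_if_exists=True`, discovers the latest "-last" checkpoint on its own. The sketch below is a simplified illustration of that precedence as described in this diff, not NeMo's `AutoResume` implementation; the helper name, the "-last" suffix search, and the directory layout are assumptions.

```python
# Simplified illustration of the resume precedence described above; the helper
# name and directory layout are hypothetical, not NeMo's AutoResume.
from pathlib import Path
from typing import Optional


def pick_resume_checkpoint(
    results_dir: Path,
    restore_from_checkpoint_path: Optional[Path] = None,
    resume_if_exists: bool = True,
    resume_ignore_no_checkpoint: bool = True,
) -> Optional[Path]:
    # An explicit restore path always wins over auto-discovery.
    if restore_from_checkpoint_path is not None:
        return restore_from_checkpoint_path
    if resume_if_exists:
        # Look for the most recent "-last" checkpoint written by ModelCheckpoint.
        last_ckpts = sorted(results_dir.glob("**/*-last"), key=lambda p: p.stat().st_mtime)
        if last_ckpts:
            return last_ckpts[-1]
        if not resume_ignore_no_checkpoint:
            raise FileNotFoundError(f"No checkpoint found under {results_dir}")
    return None  # start training from scratch


print(pick_resume_checkpoint(Path("results"), resume_if_exists=True))
```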
@@ -373,6 +396,39 @@ def main(
         required=False,
         help="Path to nemo1 file, if desired to load at init time.",
     )
+    parser.add_argument(
+        "--save-best-checkpoint",
+        action="store_true",
+        default=True,
+        help="Save the best checkpoint based on the metric to monitor.",
+    )
+    parser.add_argument(
+        "--save-last-checkpoint",
+        action="store_true",
+        default=True,
+        help="Save the last checkpoint.",
+    )
+    parser.add_argument(
+        "--metric-to-monitor-for-checkpoints",
+        type=str,
+        required=False,
+        default="val_loss",
+        help="The metric to monitor for checkpointing.",
+    )
+    parser.add_argument(
+        "--save-top-k",
+        type=int,
+        required=False,
+        default=2,
+        help="Save the top k checkpoints.",
+    )
+    parser.add_argument(
+        "--restore-from-checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.",
+    )

     # Parse the arguments and pull them out into local variables for ease of future refactor to a
     # config management system.
@@ -398,4 +454,10 @@ def main(
         experiment_name=args.experiment_name,
         resume_if_exists=args.resume_if_exists,
         nemo1_init_path=args.nemo1_init_path,
+        restore_from_checkpoint_path=args.restore_from_checkpoint_path,
+        save_best_checkpoint=args.save_best_checkpoint,
+        save_last_checkpoint=args.save_last_checkpoint,
+        metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
+        save_top_k=args.save_top_k,
+        save_every_n_steps=args.val_check_interval,
     )
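For reference, here is a self-contained sketch of only the new argparse flags added above, parsed against a sample command line. The parser is standalone and the sample values are hypothetical; the real script defines many more arguments.

```python
# Standalone excerpt of the new checkpoint-related flags; sample values are made up.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser(description="Checkpoint-related flags (excerpt)")
parser.add_argument("--save-best-checkpoint", action="store_true", default=True,
                    help="Save the best checkpoint based on the metric to monitor.")
parser.add_argument("--save-last-checkpoint", action="store_true", default=True,
                    help="Save the last checkpoint.")
parser.add_argument("--metric-to-monitor-for-checkpoints", type=str, default="val_loss",
                    help="The metric to monitor for checkpointing.")
parser.add_argument("--save-top-k", type=int, default=2,
                    help="Save the top k checkpoints.")
parser.add_argument("--restore-from-checkpoint-path", type=Path, default=None,
                    help="Path to the checkpoint directory to restore from.")

args = parser.parse_args(["--save-top-k", "3", "--restore-from-checkpoint-path", "/tmp/example-ckpt"])
print(args.save_top_k)                    # 3
print(args.restore_from_checkpoint_path)  # /tmp/example-ckpt (a pathlib.Path)
# Note: with action="store_true" and default=True, these two evaluate to True
# whether or not the flags are passed on the command line.
print(args.save_best_checkpoint, args.save_last_checkpoint)  # True True
```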