
Commit 4ee5ea1

Checkpointing works with autoresume in this PR. Updates to match the new structure. Change save-every-n to reuse val_check_interval (#24)
Signed-off-by: Steven Kothen-Hill <[email protected]>
1 parent 71d22a3 commit 4ee5ea1

File tree: 4 files changed (+93, -17 lines)


scripts/singlecell/geneformer/pretrain.py

Lines changed: 74 additions & 12 deletions
```diff
@@ -28,9 +28,10 @@
 from megatron.core.optimizer import OptimizerConfig
 from nemo import lightning as nl
 from nemo.collections import llm
+from nemo.lightning import io, resume
+from nemo.lightning.pytorch import callbacks as nl_callbacks
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule
 from nemo.lightning.pytorch.optim.lr_scheduler import CosineAnnealingScheduler
-from nemo.lightning.resume import AutoResume
 from nemo.utils import logging
 from pytorch_lightning.callbacks import LearningRateMonitor, RichModelSummary
 from torch.nn import functional as F
@@ -71,6 +72,12 @@ def main(
     wandb_entity: str = "clara-discovery",
     create_tensorboard_logger: bool = False,
     nemo1_init_path: Optional[Path] = None,
+    restore_from_checkpoint_path: Optional[str] = None,
+    save_best_checkpoint: bool = True,
+    save_last_checkpoint: bool = True,
+    metric_to_monitor_for_checkpoints: str = "val_loss",
+    save_top_k: int = 2,
+    save_every_n_steps: int = 100,
 ) -> None:
     """Train a Geneformer model on single cell data.
 
@@ -84,9 +91,8 @@ def main(
         wandb_offline (bool): if wandb should happen in offline mode
         num_steps (int): number of steps to train the model for
         limit_val_batches (int): limit the number of validation global batches to this many
-        val_check_interval (int): number of steps to periodically check the validation loss and save
-            an updated checkpoint
-        num_dataset_workers (int): num dataset workers
+        val_check_interval (int): number of steps to periodically check the validation loss and save num_dataset_workers (
+            int): num dataset workers
         biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
         lr (float): learning rate
         micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
@@ -97,8 +103,8 @@ def main(
         resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
         wandb_entity (str): the group to use for the wandb run, sometimes called a team, could also be your username
         create_tensorboard_logger (bool): create the tensorboard logger
-
-
+        restore_from_checkpoint_path (path): If set, restores the model from the directory passed in. Expects the
+            checkpoint to be created by using the ModelCheckpoint class and enable_nemo_ckpt_io=True.
     """
     # Create the result directory if it does not exist.
     result_dir.mkdir(parents=True, exist_ok=True)
@@ -115,7 +121,7 @@ def main(
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         ddp="megatron",
         find_unused_parameters=True,
-        enable_nemo_ckpt_io=False,
+        ckpt_include_optimizer=True,
     )
 
     wandb_options: Optional[WandbLoggerOptions] = (
```
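Note the strategy change in the last hunk above: `enable_nemo_ckpt_io=False` is dropped from `MegatronStrategy` in favor of `ckpt_include_optimizer=True`, so distributed checkpoints now carry optimizer state alongside the weights (which is what makes mid-run resumption exact), while NeMo-style checkpoint IO is instead enabled on the `ModelCheckpoint` callback further down.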
```diff
@@ -136,11 +142,15 @@ def main(
         limit_val_batches=limit_val_batches,  # This controls upsampling and downsampling
         val_check_interval=val_check_interval,  # TODO(@jstjohn) Checkpoint saving is currently broken, fix and change this.
         num_nodes=num_nodes,
-        callbacks=[LossLoggingCallback(), RichModelSummary(max_depth=4), LearningRateMonitor()],
+        callbacks=[
+            # TODO(@skothenhill-nv) these need to be cleaned up when we have the automatic addition of track_io
+            io.track_io(LossLoggingCallback)(),
+            io.track_io(RichModelSummary)(max_depth=4),
+            io.track_io(LearningRateMonitor)(),
+        ],
         plugins=nl.MegatronMixedPrecision(precision=precision, amp_O2=False),
     )
 
-    # Preprocess the data to get the tokenizer and median dictionary
     preprocessor = GeneformerPreprocess(
         download_directory=train_data_path,
         medians_file_path=train_data_path / "medians.json",
```
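The callback change above is the manual form of NeMo 2's IO tracking: `io.track_io` wraps a class so that objects constructed through it record their constructor arguments, which is what lets the `enable_nemo_ckpt_io=True` checkpointing below serialize and re-create them. A minimal sketch of the same pattern in isolation (the plain-vs-tracked contrast is illustrative; the behavior is inferred from how this commit uses the API):

```python
from nemo.lightning import io
from pytorch_lightning.callbacks import LearningRateMonitor

# Plain construction: configured the same way, but not serializable by
# NeMo's checkpoint IO.
plain = LearningRateMonitor()

# io.track_io(Cls) returns the class with IO tracking attached; calling it
# behaves like the normal constructor while capturing the init args.
tracked = io.track_io(LearningRateMonitor)()
```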
```diff
@@ -224,22 +234,35 @@ def main(
             ),
         ),
     )
+    # Configure our custom Checkpointer
+    checkpoint_callback = nl_callbacks.ModelCheckpoint(
+        save_best_model=save_best_checkpoint,
+        save_last=save_last_checkpoint,
+        monitor=metric_to_monitor_for_checkpoints,  # "val_loss",
+        save_top_k=save_top_k,
+        every_n_train_steps=save_every_n_steps,
+        enable_nemo_ckpt_io=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+        async_save=False,  # Tries to save asynchronously, previously led to race conditions.
+    )
 
     # Setup the logger and train the model
     nemo_logger = setup_nemo_lightning_logger(
         root_dir=result_dir,
         name=experiment_name,
         initialize_tensorboard_logger=create_tensorboard_logger,
         wandb_kwargs=wandb_options,
+        ckpt_callback=checkpoint_callback,
     )
-
     llm.train(
         model=model,
         data=data,
         trainer=trainer,
         log=nemo_logger,
-        # FIXME @skothenhill this doesn't work yet, but this is probably close to what we are supposed to do
-        resume=AutoResume(resume_if_exists=resume_if_exists, resume_ignore_no_checkpoint=True),
+        resume=resume.AutoResume(
+            path=restore_from_checkpoint_path,  # Overrides the path found by resume_if_exists when set.
+            resume_if_exists=resume_if_exists,  # Looks for the -last checkpoint to continue training.
+            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+        ),
     )
 
 
```
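Pulled out of the diff, the new checkpoint-plus-autoresume wiring comes down to two objects: a `ModelCheckpoint` that periodically saves (and tags a `-last`) checkpoint, and an `AutoResume` that finds that checkpoint on the next launch. A minimal sketch using the same keyword arguments as above, with the script's default values hard-coded for readability:

```python
from nemo.lightning import resume
from nemo.lightning.pytorch import callbacks as nl_callbacks

checkpoint_callback = nl_callbacks.ModelCheckpoint(
    save_best_model=True,  # keep a copy of the best checkpoint so far
    save_last=True,  # keep the -last checkpoint that AutoResume looks for
    monitor="val_loss",  # metric used to rank the "best" checkpoint
    save_top_k=2,
    every_n_train_steps=100,  # the script passes val_check_interval here
    enable_nemo_ckpt_io=True,  # serialize IOMixin-tracked objects with the weights
    async_save=False,  # async saving previously led to race conditions
)

auto_resume = resume.AutoResume(
    path=None,  # set to a checkpoint directory to override discovery
    resume_if_exists=True,  # pick up the -last checkpoint when one exists
    resume_ignore_no_checkpoint=True,  # first run: no checkpoint is not an error
)
```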
```diff
@@ -373,6 +396,39 @@ def main(
         required=False,
         help="Path to nemo1 file, if desired to load at init time.",
     )
+    parser.add_argument(
+        "--save-best-checkpoint",
+        action="store_true",
+        default=True,
+        help="Save the best checkpoint based on the metric to monitor.",
+    )
+    parser.add_argument(
+        "--save-last-checkpoint",
+        action="store_true",
+        default=True,
+        help="Save the last checkpoint.",
+    )
+    parser.add_argument(
+        "--metric-to-monitor-for-checkpoints",
+        type=str,
+        required=False,
+        default="val_loss",
+        help="The metric to monitor for checkpointing.",
+    )
+    parser.add_argument(
+        "--save-top-k",
+        type=int,
+        required=False,
+        default=2,
+        help="Save the top k checkpoints.",
+    )
+    parser.add_argument(
+        "--restore-from-checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.",
+    )
 
     # Parse the arguments and pull them out into local variables for ease of future refactor to a
     # config management system.
```
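One caveat in the flags above: `action="store_true"` combined with `default=True` means `--save-best-checkpoint` and `--save-last-checkpoint` can never actually be disabled from the command line. A common fix, shown here as a sketch rather than as part of this commit, is `argparse.BooleanOptionalAction` (Python 3.9+), which also generates a paired `--no-...` flag:

```python
import argparse

parser = argparse.ArgumentParser()
# Generates both --save-best-checkpoint and --no-save-best-checkpoint,
# so the default of True stays overridable from the CLI.
parser.add_argument(
    "--save-best-checkpoint",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Save the best checkpoint based on the metric to monitor.",
)
```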
```diff
@@ -398,4 +454,10 @@ def main(
         experiment_name=args.experiment_name,
         resume_if_exists=args.resume_if_exists,
         nemo1_init_path=args.nemo1_init_path,
+        restore_from_checkpoint_path=args.restore_from_checkpoint_path,
+        save_best_checkpoint=args.save_best_checkpoint,
+        save_last_checkpoint=args.save_last_checkpoint,
+        metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
+        save_top_k=args.save_top_k,
+        save_every_n_steps=args.val_check_interval,
     )
```
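Note the last added line: rather than exposing a separate `--save-every-n-steps` flag, the script forwards `args.val_check_interval` as `save_every_n_steps` (the "reuse val check interval" part of the commit message). Saving on the same cadence as validation keeps the monitored `val_loss` fresh whenever `ModelCheckpoint` ranks checkpoints.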

sub-packages/bionemo-geneformer/src/bionemo/geneformer/tokenizer/gene_tokenizer.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -19,6 +19,8 @@
 from copy import deepcopy
 from typing import Dict, List, Sequence, Tuple, TypeVar, Union
 
+from nemo.lightning import io
+
 from bionemo.geneformer.tokenizer.label2id_tokenizer import Label2IDTokenizer
 
 
@@ -27,7 +29,7 @@
 T = TypeVar("T", bound="GeneTokenizer")
 
 
-class GeneTokenizer(Label2IDTokenizer):
+class GeneTokenizer(Label2IDTokenizer, io.IOMixin):
     """Initializes the GeneTokenizer object."""
 
     cls_token: str = "[CLS]"
```

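Mixing `io.IOMixin` into the tokenizer is what allows `enable_nemo_ckpt_io=True` to serialize it alongside the model: the mixin captures constructor arguments so the object can be rebuilt at load time. A generic, hedged sketch of the pattern (the class and vocabulary here are illustrative, not from the repo):

```python
from nemo.lightning import io


class TinyTokenizer(io.IOMixin):
    """Illustrative only: an IOMixin subclass gets its __init__ args tracked."""

    def __init__(self, vocab: dict):
        self.vocab = vocab


tok = TinyTokenizer({"[CLS]": 0, "[MASK]": 1})
# tok's init args are now recorded, so NeMo's checkpoint IO can re-create it.
```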
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -32,7 +32,7 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.utils import get_linear_layer
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
-from nemo.lightning import get_vocab_size
+from nemo.lightning import get_vocab_size, io
 from nemo.lightning.megatron_parallel import MegatronLossReduction
 from torch import Tensor
 from torch.optim import Optimizer
@@ -337,7 +337,14 @@ def forward(
 
 
 @dataclass
-class BioBertConfig(BionemoTrainableModelConfig[MegatronBioBertModel, MegatronLossReduction], TransformerConfig):  # noqa: D101
+class BioBertConfig(
+    BionemoTrainableModelConfig[MegatronBioBertModel, MegatronLossReduction], TransformerConfig, io.IOMixin
+):
+    """Config class for BioBert model, responsible for the partial configuration of Transformer models.
+
+    `configure_model()` is ultimately called by the LightningModule using PTL lightning module hooks.
+    """
+
     # From megatron.core.models.gpt.bert_model.GPTModel
     fp16_lm_cross_entropy: bool = False
     parallel_output: bool = True
```

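The new docstring describes the config-driven construction pattern: the config owns `configure_model()`, and the LightningModule invokes it from PyTorch Lightning's `configure_model` hook so the Megatron module is built only after the distributed environment is initialized. A hypothetical illustration of that flow (the class name and exact call signature are assumptions, not this repo's code):

```python
import pytorch_lightning as pl


class ConfigDrivenModule(pl.LightningModule):  # illustrative name
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.module = None

    def configure_model(self) -> None:
        # PTL calls this hook once the strategy and process groups exist,
        # so model-parallel construction happens in the right context.
        if self.module is None:
            self.module = self.config.configure_model()
```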
sub-packages/bionemo-llm/src/bionemo/llm/utils/logger_utils.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -16,6 +16,7 @@
 from typing import Any, Dict, Optional, Sequence, TypedDict
 
 from nemo.lightning.nemo_logger import NeMoLogger
+from nemo.lightning.pytorch import callbacks as nemo_callbacks
 from nemo.utils import logging
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 
@@ -41,9 +42,10 @@ class WandbLoggerOptions(TypedDict):
 
 
 def setup_nemo_lightning_logger(
     name: str = "default-name",
-    root_dir: str = "./results",
+    root_dir: str | pathlib.Path = "./results",
     initialize_tensorboard_logger: bool = False,
     wandb_kwargs: Optional[WandbLoggerOptions] = None,
+    ckpt_callback: Optional[nemo_callbacks.ModelCheckpoint] = None,
     **kwargs: Dict[str, Any],
 ) -> NeMoLogger:
     """Setup the logger for the experiment.
@@ -53,6 +55,8 @@ def setup_nemo_lightning_logger(
         root_dir: The root directory to create the `name` directory in for saving run results.
         initialize_tensorboard_logger: Whether to initialize the tensorboard logger.
         wandb_kwargs: The kwargs for the wandb logger.
+        ckpt_callback: The checkpoint callback to use, must be a child of the pytorch lightning ModelCheckpoint callback.
+            NOTE the type annotation in the underlying NeMoCheckpoint constructor is incorrect.
         **kwargs: The kwargs for the NeMoLogger.
 
     Returns:
@@ -71,10 +75,11 @@ def setup_nemo_lightning_logger(
         tb_logger = None
         logging.warning("User-set tensorboard is currently turned off. Internally one may still be set by NeMo2.")
     logger: NeMoLogger = NeMoLogger(
-        dir=root_dir,
+        dir=str(root_dir),
         name=name,
         tensorboard=tb_logger,
         wandb=wandb_logger,
+        ckpt=ckpt_callback,
         **kwargs,
     )
     # Needed so that the trainer can find an output directory for the profiler
```

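With the new `ckpt_callback` parameter, checkpointing is threaded through logger setup and handed to `NeMoLogger` via its `ckpt` argument, as the diff shows. A minimal usage sketch (the experiment name and directory are illustrative):

```python
from pathlib import Path

from nemo.lightning.pytorch import callbacks as nemo_callbacks

from bionemo.llm.utils.logger_utils import setup_nemo_lightning_logger

ckpt = nemo_callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2, save_last=True)
nemo_logger = setup_nemo_lightning_logger(
    name="geneformer-pretrain",  # illustrative experiment name
    root_dir=Path("./results"),  # a Path now works; it is str()-converted internally
    initialize_tensorboard_logger=False,
    ckpt_callback=ckpt,
)
```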