Update ddp config to improve ESM-2 15B MFU #520

Open · wants to merge 9 commits into base: main
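For context, MFU (Model FLOPs Utilization) is the fraction of the hardware's peak FLOP/s that the training run actually sustains in model FLOPs. A rough estimate for a 15B-parameter ESM-2, using the standard 6·N FLOPs-per-token approximation for forward plus backward; the throughput and peak numbers below are illustrative placeholders, not measurements from this PR:

```python
# Back-of-the-envelope MFU estimate. The 6 * N FLOPs-per-token factor is the
# usual dense-transformer approximation for forward + backward; the throughput
# and peak values below are hypothetical, not taken from this PR.
n_params = 15e9                  # ESM-2 15B parameter count
tokens_per_sec_per_gpu = 4_000   # hypothetical measured throughput
peak_flops_per_gpu = 989e12      # e.g. H100 SXM BF16 dense peak

model_flops_per_sec = 6 * n_params * tokens_per_sec_per_gpu
mfu = model_flops_per_sec / peak_flops_per_gpu
print(f"MFU ~ {mfu:.1%}")        # ~36% with the placeholder numbers above
```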
@@ -18,6 +18,7 @@
from typing import List, Optional, Sequence, get_args

from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
@@ -91,6 +92,10 @@ def main(
hidden_size: int = 1280,
num_attention_heads: int = 20,
ffn_hidden_size: int = 1280 * 4,
overlap_grad_reduce: bool = True,
overlap_param_gather: bool = False, # TODO waiting for a NeMo fix
average_in_collective: bool = True,
grad_reduce_in_fp32: bool = False,
) -> None:
"""Train an ESM2 model on UR data.

@@ -146,6 +151,10 @@ def main(
hidden_size (int): hidden size
num_attention_heads (int): number of attention heads
ffn_hidden_size (int): feed forward hidden size
overlap_grad_reduce (bool): overlap gradient reduce-scatter/all-reduce with the backward pass
overlap_param_gather (bool): overlap parameter all-gather with the forward pass
average_in_collective (bool): average gradients inside the reduction collective rather than summing and dividing afterwards
grad_reduce_in_fp32 (bool): perform gradient reduction in fp32 rather than the training dtype
"""
# Create the result directory if it does not exist.
result_dir.mkdir(parents=True, exist_ok=True)
@@ -163,10 +172,18 @@ def main(
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
ddp="megatron",
pipeline_dtype=get_autocast_dtype(precision),
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
overlap_grad_reduce=overlap_grad_reduce,
overlap_param_gather=overlap_param_gather,
average_in_collective=average_in_collective,
grad_reduce_in_fp32=grad_reduce_in_fp32,
use_distributed_optimizer=True,
),
find_unused_parameters=True,
gradient_as_bucket_view=True,
ckpt_include_optimizer=True,
# NOTE: async checkpoint saving can occasionally cause issues, most recently observed due to duplicate filenames.
ckpt_async_save=True,
ckpt_parallel_load=True,
)
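For reference, a standalone sketch of the new DDP config with the intent of each flag spelled out; the flag semantics reflect my reading of Megatron-Core's DistributedDataParallelConfig and should be checked against the installed version:

```python
from megatron.core.distributed import DistributedDataParallelConfig

# Minimal sketch of the config used above, with the intent of each flag noted.
ddp_config = DistributedDataParallelConfig(
    check_for_nan_in_grad=True,      # fail fast if a NaN appears in the gradients
    overlap_grad_reduce=True,        # overlap grad reduce-scatter/all-reduce with the backward pass
    overlap_param_gather=False,      # would overlap param all-gather with the forward pass (off pending NeMo fix)
    average_in_collective=True,      # average inside the collective instead of summing, then dividing locally
    grad_reduce_in_fp32=False,       # reduce grads in the training dtype, halving communication volume
    use_distributed_optimizer=True,  # shard optimizer state across data-parallel ranks (ZeRO-1 style)
)
```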
@@ -213,7 +230,13 @@ def main(
log_every_n_steps=log_every_n_steps,
num_nodes=num_nodes,
callbacks=callbacks,
plugins=nl.MegatronMixedPrecision(precision=precision),
plugins=nl.MegatronMixedPrecision(
precision=precision,
params_dtype=get_autocast_dtype(precision),
pipeline_dtype=get_autocast_dtype(precision),
grad_reduce_in_fp32=grad_reduce_in_fp32,
autocast_enabled=False,
),
)
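The params_dtype and pipeline_dtype values come from get_autocast_dtype; below is a sketch of the mapping I assume it performs (the real implementation in bionemo.core.utils.dtypes is authoritative):

```python
import torch

# Assumed precision-string -> torch dtype mapping; verify against
# bionemo.core.utils.dtypes.get_autocast_dtype.
_PRECISION_TO_DTYPE = {
    "bf16-mixed": torch.bfloat16,
    "16-mixed": torch.float16,
    "32": torch.float32,
}

def get_autocast_dtype_sketch(precision: str) -> torch.dtype:
    return _PRECISION_TO_DTYPE[precision]
```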

tokenizer = get_tokenizer()
@@ -360,6 +383,10 @@ def train_esm2_entrypoint():
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
ffn_hidden_size=args.ffn_hidden_size,
overlap_grad_reduce=not args.no_overlap_grad_reduce,
overlap_param_gather=args.overlap_param_gather,
average_in_collective=not args.no_average_in_collective,
grad_reduce_in_fp32=args.grad_reduce_in_fp32,
)


@@ -667,6 +694,27 @@ def get_parser():
default=4 * 1280,
help="FFN hidden size of the model. Default is 4 * 1280.",
)
# DDP config
parser.add_argument(
"--no-overlap-grad-reduce",
action="store_true",
default=False,
help="Disable overlapping of gradient reduction with the backward pass.",
)
parser.add_argument(
"--overlap-param-gather",
action="store_true",
default=False,
help="Overlap parameter all-gather with the forward pass.",
) # TODO waiting for a NeMo fix
parser.add_argument(
"--no-average-in-collective",
action="store_true",
default=False,
help="Disable averaging of gradients inside the reduction collective.",
)
parser.add_argument(
"--grad-reduce-in-fp32",
action="store_true",
default=False,
help="Perform gradient reduction in fp32 rather than the training dtype.",
)
return parser
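A sketch of how the new flags flow from the CLI into the DDP config, illustrating the negated-flag convention (overlap_grad_reduce and average_in_collective default to on; the --no-... flags turn them off). The argument list is illustrative and omits the usual required training arguments:

```python
# Illustrative only: required data/training arguments are omitted here.
args = get_parser().parse_args(["--no-overlap-grad-reduce", "--grad-reduce-in-fp32"])

overlap_grad_reduce = not args.no_overlap_grad_reduce       # -> False (explicitly disabled)
overlap_param_gather = args.overlap_param_gather            # -> False (default, pending NeMo fix)
average_in_collective = not args.no_average_in_collective   # -> True  (default stays on)
grad_reduce_in_fp32 = args.grad_reduce_in_fp32              # -> True  (explicitly enabled)
```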


21 changes: 18 additions & 3 deletions sub-packages/bionemo-llm/src/bionemo/llm/train.py
@@ -20,6 +20,7 @@
from typing import Optional

from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
@@ -30,6 +31,7 @@
from nemo.utils import logging
from pydantic import BaseModel

from bionemo.core.utils.dtypes import get_autocast_dtype
from bionemo.llm.lightning import BionemoLightningModule, PerplexityLoggingCallback
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
@@ -107,10 +109,17 @@ def setup_trainer(
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
ddp="megatron",
pipeline_dtype=get_autocast_dtype(training_config.precision),
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
overlap_grad_reduce=True,
overlap_param_gather=False, # TODO waiting for NeMo fix
average_in_collective=True,
use_distributed_optimizer=True,
),
find_unused_parameters=True,
gradient_as_bucket_view=True,
ckpt_include_optimizer=True,
# NOTE: async checkpoint saving can occasionally cause issues, most recently observed due to duplicate filenames.
ckpt_async_save=True,
ckpt_parallel_load=True,
)
@@ -151,7 +160,13 @@ def setup_trainer(
val_check_interval=training_config.val_check_interval,
num_nodes=parallel_config.num_nodes,
callbacks=callbacks,
plugins=nl.MegatronMixedPrecision(precision=training_config.precision),
plugins=nl.MegatronMixedPrecision(
precision=training_config.precision,
params_dtype=get_autocast_dtype(training_config.precision),
pipeline_dtype=get_autocast_dtype(training_config.precision),
grad_reduce_in_fp32=False,
autocast_enabled=False,
),
)
return trainer
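Since grad_reduce_in_fp32 now appears both in the DDP config and in MegatronMixedPrecision, here is a hedged sketch of deriving both from a single value so the two settings cannot drift apart. The helper name is illustrative, not part of this PR, and it assumes the imports already present in train.py:

```python
def build_precision_and_ddp(precision: str, grad_reduce_in_fp32: bool = False):
    """Illustrative helper: build the mixed-precision plugin and DDP config from
    one grad_reduce_in_fp32 value so the two settings stay consistent."""
    dtype = get_autocast_dtype(precision)
    ddp = DistributedDataParallelConfig(
        check_for_nan_in_grad=True,
        overlap_grad_reduce=True,
        overlap_param_gather=False,  # TODO waiting for NeMo fix
        average_in_collective=True,
        grad_reduce_in_fp32=grad_reduce_in_fp32,
        use_distributed_optimizer=True,
    )
    plugin = nl.MegatronMixedPrecision(
        precision=precision,
        params_dtype=dtype,
        pipeline_dtype=dtype,
        grad_reduce_in_fp32=grad_reduce_in_fp32,
        autocast_enabled=False,
    )
    return plugin, ddp
```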
