Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling LR scaling for a specific layer (ex. down-projection...) during pretraining #1262

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion megatron/core/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def _get_param_groups(
Creates parameter groups based on weight decay condition (regularized vs
non regularized), learning rate scale condition (lr vs lr_mult * lr),
and whether it is expert parameters. scale_lr_cond is used during finetuning
where head of the network requires a scaled version of the base learning rate.
where head of the network can have a scaled version of the base learning rate or
during pre-training where down-projection layer (linear_fc2) can have a lower learning rate.

Args:
model_chunks (List[MegatronModule]): model chunks to create parameter
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TransformerConfig(ModelParallelConfig):

# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
"""If True, uses the original BERT residule connection ordering."""
"""If True, uses the original BERT residual connection ordering."""

layernorm_epsilon: float = 1e-5
"""Epsilon value for any LayerNorm operations."""
Expand Down
8 changes: 6 additions & 2 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,12 @@ def _add_learning_rate_args(parser):
group.add_argument('--decoupled-min-lr', type=float, default=None,
help='Minimum value for learning rate for the input and output layer. The scheduler'
'clip values below this threshold')
group.add_argument('--scale-lr-layer', type=str, default=None,
help='Scale learning rate for the specified layer.'
'E.g. --scale-lr-layer "linear_fc2" to scale lr for down-proj layer (during pretraining or finetuning).'
'Or, --scale-lr-layer "head" to scale lr for lm-head (during pretraining or finetuning).')
group.add_argument('--lr-multiplier', type=float, default=1.0,
help='Learning rate multiplier for the specified layer in scale-lr-layer.')

return parser

Expand Down Expand Up @@ -1821,8 +1827,6 @@ def _add_vision_args(parser):
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')

# pretraining type and backbone selection`
group.add_argument('--vision-pretraining', action='store_true',
Expand Down
5 changes: 4 additions & 1 deletion megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,10 @@ def pretrain(
timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
app_metrics['app_build_optimizer_start_time'] = one_logger_utils.get_timestamp_in_ms()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
model_provider, model_type, checkpointing_context=checkpointing_context)
model_provider, model_type,
scale_lr_cond=(lambda name, param: args.scale_lr_layer in name) if args.scale_lr_layer else None,
lr_mult=args.lr_multiplier,
checkpointing_context=checkpointing_context)

timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
Expand Down
4 changes: 2 additions & 2 deletions tasks/vision/finetune_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def finetune(
setup_model_and_optimizer(
model_provider,
model_type,
scale_lr_cond=lambda name, param: ".head." in name,
lr_mult=args.head_lr_mult)
scale_lr_cond=(lambda name, param: args.scale_lr_layer in name) if args.scale_lr_layer else None,
lr_mult=args.lr_multiplier)
timers("model and optimizer").stop()

# If pretrained checkpoint is provided and we have not trained for
Expand Down