
Commit

ADLR/megatron-lm!2285 - Support --freeze-LM and --freeze-ViT with ranks that don't have trainable params

Co-authored-by: Jon Barker <[email protected]>
jon-barker and Jon Barker committed Dec 21, 2024
1 parent cf25d44 commit 1468ab0
Showing 16 changed files with 404 additions and 159 deletions.
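For context, the frozen-sub-model case this commit supports amounts to disabling gradients on an entire sub-module; with pipeline or tensor parallelism, a rank that holds only frozen modules then has no trainable parameters at all. A minimal illustrative sketch of that idea (not Megatron's actual freeze code, which is driven by the --freeze-LM / --freeze-ViT flags):

    import torch

    def freeze_module(module: torch.nn.Module) -> None:
        # With requires_grad disabled, these parameters never reach the
        # optimizer's param_groups, which can leave a rank with an empty list.
        for param in module.parameters():
            param.requires_grad = False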
4 changes: 3 additions & 1 deletion megatron/core/models/multimodal/llava_model.py
@@ -16,7 +16,7 @@
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import get_batch_on_this_cp_rank, log_single_rank
+from megatron.core.utils import log_single_rank

 try:
     import transformer_engine  # pylint: disable=unused-import
@@ -637,6 +637,8 @@ def _process_embedding_token_parallel(

         if self.context_parallel_lm > 1:
             # Distribute sequence across CP ranks
+            from megatron.training.utils import get_batch_on_this_cp_rank
+
             batch = get_batch_on_this_cp_rank(
                 {
                     "combined_embeddings": combined_embeddings,
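The relocated import is used to split the combined embedding sequence across context-parallel (CP) ranks. As a rough illustration only (the real get_batch_on_this_cp_rank in megatron.training.utils may partition the sequence differently, e.g. for attention load balancing), a naive contiguous split of the sequence dimension looks like this:

    import torch

    def naive_cp_slice(batch: dict, cp_rank: int, cp_size: int, seq_dim: int = 1) -> dict:
        # Keep only this CP rank's contiguous chunk of the sequence dimension;
        # assumes the sequence length divides evenly by cp_size.
        sliced = {}
        for key, tensor in batch.items():
            if tensor is None:
                sliced[key] = None
                continue
            chunk = tensor.size(seq_dim) // cp_size
            sliced[key] = tensor.narrow(seq_dim, cp_rank * chunk, chunk)
        return sliced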
87 changes: 48 additions & 39 deletions megatron/core/optimizer/__init__.py
@@ -262,48 +262,56 @@ def _get_megatron_optimizer_based_on_param_groups(
     Returns:
         Instance of MegatronOptimizer.
     """
-    if config.optimizer == 'adam':
-        kwargs = {
-            "params": param_groups,
-            "lr": config.lr,
-            "weight_decay": config.weight_decay,
-            "betas": (config.adam_beta1, config.adam_beta2),
-            "eps": config.adam_eps,
-        }
-
-        if config.use_precision_aware_optimizer:
-            kwargs.update(
-                {
-                    "master_weights": True,
-                    "use_decoupled_grad": True,
-                    "master_weight_dtype": config.main_params_dtype,
-                    "exp_avg_dtype": config.exp_avg_dtype,
-                    "exp_avg_sq_dtype": config.exp_avg_sq_dtype,
-                }
-            )
-
-        optimizer = Adam(**kwargs)
-
-        def init_state_fn(opt, config=None):
-            for group in opt.param_groups:
-                for p in group['params']:
-                    if len(opt.state[p]) == 0:
-                        if config is None or not config.use_precision_aware_optimizer:
-                            opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
-                            opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
-                        else:
-                            opt.initialize_state(p)
-
-    elif config.optimizer == 'sgd':
-        optimizer = SGD(
-            param_groups,
-            lr=config.lr,
-            weight_decay=config.weight_decay,
-            momentum=config.sgd_momentum,
-        )
-        init_state_fn = None
-    else:
-        raise Exception('{} optimizer is not supported.'.format(config.optimizer))
+    # when freezing sub-models we may have no trainable parameters on a rank and
+    # hence an empty param_groups. However, we still need to create an optimizer
+    # for the purposes of grad stats reductions
+    if param_groups:
+        if config.optimizer == 'adam':
+            kwargs = {
+                "params": param_groups,
+                "lr": config.lr,
+                "weight_decay": config.weight_decay,
+                "betas": (config.adam_beta1, config.adam_beta2),
+                "eps": config.adam_eps,
+            }
+
+            if config.use_precision_aware_optimizer:
+                kwargs.update(
+                    {
+                        "master_weights": True,
+                        "use_decoupled_grad": True,
+                        "master_weight_dtype": config.main_params_dtype,
+                        "exp_avg_dtype": config.exp_avg_dtype,
+                        "exp_avg_sq_dtype": config.exp_avg_sq_dtype,
+                    }
+                )
+
+            optimizer = Adam(**kwargs)
+
+            def init_state_fn(opt, config=None):
+                for group in opt.param_groups:
+                    for p in group['params']:
+                        if len(opt.state[p]) == 0:
+                            if config is None or not config.use_precision_aware_optimizer:
+                                opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
+                                opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
+                            else:
+                                opt.initialize_state(p)
+
+        elif config.optimizer == 'sgd':
+            optimizer = SGD(
+                param_groups,
+                lr=config.lr,
+                weight_decay=config.weight_decay,
+                momentum=config.sgd_momentum,
+            )
+            init_state_fn = None
+        else:
+            raise Exception('{} optimizer is not supported.'.format(config.optimizer))
+    else:
+        optimizer = None
+        init_state_fn = None

     # Mixed precision optimizer.
     # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
@@ -423,6 +431,7 @@ def get_megatron_optimizer(
             model_chunk.overlap_param_gather_with_optimizer_step = (
                 overlap_param_gather_with_optimizer_step
             )
+
         optimizers.append(
             _get_megatron_optimizer_based_on_param_groups(
                 config,
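The comment in the hunk above captures the core idea: with --freeze-LM or --freeze-ViT, a rank may see an empty param_groups, yet an optimizer object is still created downstream so that grad-stats reductions stay aligned across ranks. A condensed sketch of that guard (illustrative only; build_optimizer is a hypothetical helper, not a Megatron API):

    import torch

    def build_optimizer(param_groups, lr=1.0e-4, weight_decay=0.01):
        # No trainable parameters on this rank: return None so the caller can
        # wrap it in a stub optimizer instead of a real Adam instance.
        if not param_groups:
            return None, None

        optimizer = torch.optim.Adam(param_groups, lr=lr, weight_decay=weight_decay)

        def init_state_fn(opt):
            # Lazily create Adam's exp_avg / exp_avg_sq state for each parameter.
            for group in opt.param_groups:
                for p in group["params"]:
                    if len(opt.state[p]) == 0:
                        opt.state[p]["exp_avg"] = torch.zeros_like(p.data)
                        opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data)

        return optimizer, init_state_fn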
22 changes: 20 additions & 2 deletions megatron/core/optimizer/distrib_optimizer.py
@@ -483,10 +483,16 @@ def __init__(
         for model_chunk in self.model_chunks:
             assert self.ddp_config == model_chunk.ddp_config

-        assert isinstance(
-            optimizer, Adam
+        assert (
+            isinstance(optimizer, Adam) or optimizer is None
         ), "Only Adam currently supported, due to checkpointing requirements."

+        # when freezing sub-models we have no real optimizer
+        # but still need a stub DistributedOptimizer class
+        if optimizer is None:
+            self.is_stub_optimizer = True
+            return
+
         # Model grad buffer ranges.
         assert per_model_buffers is not None, "per_model_buffers must be provided"
         self.buffers = list(itertools.chain(*per_model_buffers.values()))
@@ -551,6 +557,8 @@ def __init__(
         self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
         self.optimizer.load_state_dict(self.optimizer.state_dict())

+        self.is_stub_optimizer = False
+
     def _get_model_param_range_map(self, param: torch.nn.Parameter):
         """
         Given a model param, get the index sub-range of the param that this
@@ -1635,6 +1643,8 @@ def load_parameter_state(self, filename: str, *, update_legacy_format=False):
         Args:
             filename (str): path to load parameter state from.
         """
+        if self.is_stub_optimizer:
+            return
         state_dict = None
         if torch.distributed.get_rank(self.data_parallel_group) == 0:
             state_dict = torch.load(filename)
@@ -1653,6 +1663,8 @@ def zero_grad(self, set_to_none: bool = True):
         Args:
             set_to_none (bool): if true, set grads to None.
         """
+        if self.is_stub_optimizer:
+            return
         total_groups = [
             self.model_float16_groups,
             self.model_fp32_groups,
@@ -1710,6 +1722,8 @@ def _copy_model_grads_to_main_grads(self):
         buffer, this method is responsible for copying the updated grads
         from the grad buffer to the main shard's grad field.
         """
+        if self.is_stub_optimizer:
+            return

         # Utility method for copying group grads.
         def copy_group_grads(model_groups, shard_main_groups):
@@ -1748,6 +1762,8 @@ def _copy_main_params_to_model_params(self):
         buffer, this method is responsible for copying the updated params
         from the main shards into the correct position in the grad buffer.
         """
+        if self.is_stub_optimizer:
+            return

         # Utility method for copying group params.
         def copy_group_params(shard_main_groups, model_groups):
@@ -1831,6 +1847,8 @@ def _update_fp8_scale_inv_and_amax(self):
         If detect FP8 parameters, update their `_scale_inv` and do reduce-max for their
         `amax_history`.
         """
+        if self.is_stub_optimizer:
+            return
         amaxes = []
         scales = []
         scale_invs = []
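The repeated early returns above follow one pattern: when the optimizer was constructed as a stub (no trainable parameters on this rank), every state-mutating method becomes a no-op while the object itself still exists, so collective operations elsewhere in the training loop stay in sync. A schematic illustration (not the actual DistributedOptimizer class):

    class StubAwareOptimizer:
        def __init__(self, inner_optimizer=None):
            # A rank with only frozen parameters gets inner_optimizer=None.
            self.is_stub_optimizer = inner_optimizer is None
            self.inner = inner_optimizer

        def zero_grad(self, set_to_none: bool = True):
            if self.is_stub_optimizer:
                return  # nothing to zero on this rank
            self.inner.zero_grad(set_to_none=set_to_none)

        def step(self):
            if self.is_stub_optimizer:
                return  # no parameters to update, but the rank still participates
            self.inner.step()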
(Diffs for the remaining 13 changed files are not shown here.)
