Merge branch 'dnarayanan/overlap_with_optimizer_step' into 'main'
Overlap param all-gather with optimizer step and fix alignment of AGs across pipeline stages

See merge request ADLR/megatron-lm!1874
deepakn94 committed Aug 24, 2024
2 parents e32b60b + 4e38405 commit ef85bc9
Showing 12 changed files with 311 additions and 96 deletions.
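For context, a minimal usage sketch (not part of this commit) of how the new behavior is switched on through OptimizerConfig. The overlap_param_gather_with_optimizer_step flag is the field this merge request adds; the other fields and the model_chunks list are pre-existing Megatron-LM pieces assumed here.

# Illustrative sketch only; assumes model_chunks is the list of DDP-wrapped
# model chunks built by the usual training setup code (not shown here).
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer

config = OptimizerConfig(
    optimizer='adam',
    lr=3e-4,
    min_lr=3e-5,
    use_distributed_optimizer=True,
    overlap_param_gather=True,
    overlap_param_gather_with_optimizer_step=True,  # added by this MR
)
optimizer = get_megatron_optimizer(config, model_chunks)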
4 changes: 2 additions & 2 deletions examples/gpt3/gpt_config.yaml
@@ -215,9 +215,9 @@ fp16_lm_cross_entropy: False
distributed_backend: nccl
distributed_timeout_minutes: 10
overlap_grad_reduce: False
delay_grad_reduce: True
align_grad_reduce: True
overlap_param_gather: False
delay_param_gather: False
align_param_gather: False
scatter_gather_tensors_in_pipeline: True
local_rank: null
lazy_mpu_init: null
193 changes: 135 additions & 58 deletions megatron/core/optimizer/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Tuple

import torch

@@ -42,10 +42,13 @@

def _get_param_groups(
model_chunks: List[MegatronModule],
no_weight_decay_cond: Callable,
scale_lr_cond: Callable,
no_weight_decay_cond: Optional[Callable],
scale_lr_cond: Optional[Callable],
lr_mult: float,
use_decoupled_learning_rate: bool,
lr: float,
min_lr: float,
decoupled_lr: Optional[float],
decoupled_min_lr: Optional[float],
) -> List[Dict]:
"""Create parameter groups for optimizer.
@@ -57,18 +60,23 @@ def _get_param_groups(
Args:
model_chunks (List[MegatronModule]): model chunks to create parameter
groups for.
no_weight_decay_cond (func): function to determine whether a parameter
should not perform weight decay.
scale_lr_cond (func): function to determine whether a parameter
no_weight_decay_cond (func, optional): function to determine whether a
parameter should not perform weight decay.
scale_lr_cond (func, optional): function to determine whether a parameter
should have a scaled learning rate.
lr_mult (float): learning rate multiplier for parameters that
satisfy scale_lr_cond.
use_decoupled_learning_rate (bool): true if using decoupled learning rate.
lr (float): learning rate.
min_lr (float): minimum learning rate.
decoupled_lr (Optional[float]): optional decoupled learning rate.
decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.
Returns:
List of parameter groups.
"""

use_decoupled_learning_rate = decoupled_lr is not None

# Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
params_map = {}
for model_chunk in model_chunks:
@@ -113,15 +121,22 @@ def _get_param_groups(
param_groups = []
for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
assert len(params) > 0
param_groups.append(
{
'params': params,
'wd_mult': wd_mult,
'lr_mult': _lr_mult,
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': is_decoupled_lr,
}
)
param_group = {
'params': params,
'wd_mult': wd_mult,
'lr_mult': _lr_mult,
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': is_decoupled_lr,
}
param_groups.append(param_group)

param_groups = _update_min_and_max_lr_in_param_groups(
param_groups,
lr=lr,
min_lr=min_lr,
decoupled_lr=decoupled_lr,
decoupled_min_lr=decoupled_min_lr,
)

return param_groups
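For reference, roughly what one entry of the list returned above looks like. The max_lr and min_lr keys are an assumption based on _update_min_and_max_lr_in_param_groups and the surrounding code; the values are made up.

# Hypothetical example of a single dense, non-decoupled parameter group.
example_group = {
    'params': [...],              # torch.nn.Parameter objects in this group
    'wd_mult': 1.0,               # weight-decay multiplier
    'lr_mult': 1.0,               # learning-rate multiplier
    'is_expert_parallel': False,  # dense (non-MoE) parameters
    'is_decoupled_lr': False,     # uses lr/min_lr rather than decoupled_lr
    'max_lr': 3e-4,               # assumed to be filled in by
    'min_lr': 3e-5,               #   _update_min_and_max_lr_in_param_groups
}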

@@ -165,6 +180,56 @@ def _update_min_and_max_lr_in_param_groups(
return param_groups


def _get_param_groups_and_buffers(
model_chunks: List[MegatronModule],
model_chunk_offset: int,
config: OptimizerConfig,
no_weight_decay_cond: Optional[Callable],
scale_lr_cond: Optional[Callable],
lr_mult: float,
filter_fn: Callable,
buffer_name: str,
) -> Tuple[List[Dict], Dict[int, ParamAndGradBuffer]]:
"""Returns parameter groups and buffer for optimizer.
Args:
model_chunks (List[MegatronModule]): model chunks to create parameter
groups for.
model_chunk_offset (int): offset of model_chunks in global model_chunks list.
config (OptimizerConfig): optimizer configuration object.
no_weight_decay_cond (func, optional): function to determine whether a
parameter should not perform weight decay.
scale_lr_cond (func, optional): function to determine whether a parameter
should have a scaled learning rate.
lr_mult (float): learning rate multiplier for parameters that
satisfy scale_lr_cond.
filter_fn (callable): filtering function for param_groups.
buffer_name (str): name of buffer.
Returns:
List of parameter groups and dictionary of model chunk IDs to buffers.
"""
param_groups = _get_param_groups(
model_chunks,
no_weight_decay_cond,
scale_lr_cond,
lr_mult,
lr=config.lr,
min_lr=config.min_lr,
decoupled_lr=config.decoupled_lr,
decoupled_min_lr=config.decoupled_min_lr,
)
param_groups = list(filter(filter_fn, param_groups))
buffers = {}
for model_chunk_idx, model_chunk in enumerate(model_chunks):
if hasattr(model_chunk, buffer_name):
buffers[model_chunk_idx + model_chunk_offset] = getattr(model_chunk, buffer_name)

return param_groups, buffers


def _get_megatron_optimizer_based_on_param_groups(
config: OptimizerConfig,
param_groups: List,
@@ -173,6 +238,7 @@ def _get_megatron_optimizer_based_on_param_groups(
data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group_idx: Optional[int] = None,
overlap_param_gather_with_optimizer_step: bool = False,
) -> MegatronOptimizer:
"""Get Megatron optimizer based on parameter groups.
@@ -186,6 +252,8 @@
group for distributed optimizer. Defaults to None.
data_parallel_group_idx (int, optional): data-parallel group index for distributed
optimizer. Defaults to None.
overlap_param_gather_with_optimizer_step (bool, optional): if true, overlap parameter
all-gather with optimizer step if using distributed optimizer. Defaults to False.
Returns:
Instance of MegatronOptimizer.
@@ -255,6 +323,7 @@ def init_state_fn(opt):
data_parallel_group=data_parallel_group,
data_parallel_group_gloo=data_parallel_group_gloo,
data_parallel_group_idx=data_parallel_group_idx,
overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step,
)
else:
optimizer = Float16OptimizerWithFloat16Params(*optimizer_args)
@@ -294,56 +363,64 @@ def get_megatron_optimizer(

log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

# Collect param groups.
param_groups = _get_param_groups(
model_chunks,
no_weight_decay_cond,
scale_lr_cond,
lr_mult,
use_decoupled_learning_rate=config.decoupled_lr is not None,
)
param_groups = _update_min_and_max_lr_in_param_groups(
param_groups,
lr=config.lr,
min_lr=config.min_lr,
decoupled_lr=config.decoupled_lr,
decoupled_min_lr=config.decoupled_min_lr,
)

# Collect grad buffers for distributed optimizer.
per_model_buffers = {}
per_model_ep_buffers = {}
for model_idx, model_chunk in enumerate(model_chunks):
if hasattr(model_chunk, 'buffers'):
per_model_buffers[model_idx] = model_chunk.buffers
per_model_ep_buffers[model_idx] = model_chunk.expert_parallel_buffers

# Split param groups into dense and MoE params (since data-parallel groups for MoE
# parameters can be different with expert parallelism).
dense_param_groups = list(filter(lambda g: not g['is_expert_parallel'], param_groups))
moe_param_groups = list(filter(lambda g: g['is_expert_parallel'], param_groups))

# Create optimizers.
# Separate out first model chunk if overlapping param AG with optimizer step.
if config.overlap_param_gather_with_optimizer_step:
all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]
overlap_param_gather_with_optimizer_step_flags = [True, False]
else:
all_dense_model_chunks = [model_chunks]
overlap_param_gather_with_optimizer_step_flags = [False]
model_parallel_rank = torch.distributed.get_rank(mpu.get_model_parallel_group())
optimizers = [
_get_megatron_optimizer_based_on_param_groups(
config,
param_groups=dense_param_groups,
per_model_buffers=per_model_buffers,
model_parallel_group=mpu.get_model_parallel_group(),
data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(with_context_parallel=True),
data_parallel_group_idx=model_parallel_rank,

optimizers = []
model_chunk_offset = 0
for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
param_groups, buffers = _get_param_groups_and_buffers(
dense_model_chunks,
model_chunk_offset=model_chunk_offset,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: not g['is_expert_parallel'],
buffer_name='buffers',
)
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mpu.get_model_parallel_group(),
data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(
with_context_parallel=True
),
data_parallel_group_idx=model_parallel_rank,
overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step,
)
)
]
model_chunk_offset += 1

moe_param_groups, moe_buffers = _get_param_groups_and_buffers(
model_chunks,
model_chunk_offset=0,
config=config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
filter_fn=lambda g: g['is_expert_parallel'],
buffer_name='expert_parallel_buffers',
)
if len(moe_param_groups) > 0:
model_parallel_world_size = torch.distributed.get_world_size(mpu.get_model_parallel_group())
expert_parallel_rank = mpu.get_expert_model_parallel_rank()
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
param_groups=moe_param_groups,
per_model_buffers=per_model_ep_buffers,
per_model_buffers=moe_buffers,
model_parallel_group=mpu.get_model_parallel_group(with_expert_parallel=True),
data_parallel_group=mpu.get_data_modulo_expert_parallel_group(
with_context_parallel=True
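To make the new control flow in get_megatron_optimizer concrete, a short annotated sketch of what the loops above produce when the flag is on. This is inferred only from the hunks shown here; how the resulting list is combined into the returned optimizer is outside this excerpt.

# Illustrative only: with overlap_param_gather_with_optimizer_step=True and
# virtual-pipeline model chunks [c0, c1, c2], the dense loop builds
#   optimizers[0] -> dense params of [c0],      overlap flag True
#   optimizers[1] -> dense params of [c1, c2],  overlap flag False
# and, if any expert-parallel parameter groups exist, the MoE branch appends
#   optimizers[2] -> MoE params of all chunks (from 'expert_parallel_buffers')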
(Diffs for the remaining 10 changed files not shown.)
