Merge branch 'dnarayanan/fix_param_norm_memory_main' into 'main'
Reuse optimizer's main_params to compute param norm in a memory-efficient way

See merge request ADLR/megatron-lm!2483
ko3n1g committed Jan 3, 2025
2 parents 24e0126 + 079dc66 commit f682bd0
Showing 4 changed files with 124 additions and 13 deletions.
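
In essence, the change is about where calc_params_l2_norm gets an fp32 view of each bf16 parameter: instead of materializing a fresh fp32 copy of every parameter on every call, it reuses the fp32 master weights (main_param) that the mixed-precision optimizer already holds. Below is a minimal single-GPU sketch of the saving, using a hypothetical param_sq_norm helper; the actual code in this commit uses multi_tensor_l2norm and handles tensor-parallel, data-parallel, and MoE parameters.

import torch

def param_sq_norm(params, force_create_fp32_copy=False):
    # Sum of squared parameter values on a single GPU (sketch only).
    total = torch.zeros(1, dtype=torch.float32, device='cuda')
    for p in params:
        main = getattr(p, 'main_param', None)
        if main is not None and not force_create_fp32_copy:
            total += main.pow(2).sum()            # reuse the optimizer's existing fp32 copy
        else:
            total += p.data.float().pow(2).sum()  # fallback: allocate a temporary fp32 copy
    return total

The force_create_fp32_copy flag added to calc_params_l2_norm keeps the old fp32-copy path available, and the new unit test uses it to cross-check both code paths.
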
17 changes: 17 additions & 0 deletions megatron/core/optimizer/distrib_optimizer.py
@@ -384,6 +384,10 @@ def _build_model_and_main_param_groups(
# When using precision-aware optimizer, main params are held by FusedAdam.
shard_main_param = None

# Store handle to main_param.
model_param.main_param = shard_main_param
model_param.main_param_sharded = True

# Add to group.
model_float16_params_this_group.append(model_param)
shard_float16_params_this_group.append(shard_model_param)
@@ -535,6 +539,19 @@ def __init__(
self.gbuf_ranges.append(self._build_gbuf_range_map(buffer))
self.model_param_gbuf_map = self._build_model_param_gbuf_map(self.gbuf_ranges)

# Add main_param field to each parameter. We will use this fp32 copy to compute
# the param norm.
# For parameters with optimizer state on this rank, None will be overwritten by
# the corresponding sharded main_param tensor.
for param_group in self.optimizer.param_groups:
# For all the parameters in this group.
for param in param_group['params']:
if param.requires_grad:
# fp32 copy only needed for 16-bit parameters.
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
param.main_param = None
param.main_param_sharded = True

# Optimizer ranges.
(self.model_param_group_index_map, self.opt_group_ranges) = (
self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
3 changes: 3 additions & 0 deletions megatron/core/optimizer/optimizer.py
@@ -555,6 +555,9 @@ def __init__(
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param

# Store handle to main_param.
param.main_param = main_param

fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
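
Taken together, the two optimizer hunks above give every 16-bit parameter an fp32 handle: the distributed optimizer attaches its local fp32 shard (or None when another data-parallel rank owns that shard) and sets main_param_sharded = True, while the non-distributed mixed-precision optimizer attaches the full fp32 copy. A small illustrative helper, not part of the commit, that classifies those states:

def main_param_state(param):
    # Classify the fp32 handle attached by the optimizer changes above (illustration only).
    main = getattr(param, 'main_param', None)
    if getattr(param, 'main_param_sharded', False):
        return 'local fp32 shard' if main is not None else 'shard owned by another DP rank'
    return 'full fp32 copy' if main is not None else 'no fp32 copy (e.g. a true fp32 parameter)'
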
58 changes: 47 additions & 11 deletions megatron/training/utils.py
@@ -66,63 +66,98 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
return unwrapped_model


def calc_params_l2_norm(model):
def calc_params_l2_norm(model, force_create_fp32_copy=False):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
model = [model]
# Separate moe and dense params
params_data = []
moe_params_data = []
sharded_params_data = []
data_parallel_group = None

for model_chunk in model:
for i, param in enumerate(model_chunk.parameters()):
for param in model_chunk.parameters():
data_parallel_group = get_data_parallel_group_if_dtensor(param, data_parallel_group)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if not (param.requires_grad and is_not_tp_duplicate):
if not is_not_tp_duplicate:
continue
assert is_not_tp_duplicate
if not getattr(param, 'allreduce', True):
# TODO: Implement memory optimization for MoE parameters.
assert param_is_not_shared(param)
param = to_local_if_dtensor(param)
moe_params_data.append(param.data.float() if args.bf16 else param.data)
else:
if param_is_not_shared(param):
param = to_local_if_dtensor(param)
params_data.append(param.data.float() if args.bf16 else param.data)

# Calculate dense param norm
if args.bf16:
if not force_create_fp32_copy and hasattr(param, 'main_param'):
if getattr(param, 'main_param_sharded', False):
if param.main_param is not None:
sharded_params_data.append(param.main_param)
else:
params_data.append(param.main_param)
else:
# Fallback to original logic of making a fp32 copy of the
# parameter if `.main_param` attribute is not available.
params_data.append(param.data.float())
else:
params_data.append(param.data)

# Calculate norm.
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
if len(params_data) > 0:
norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
False # no per-parameter norm.
)
norm_2 = norm * norm
else:
norm_2 = torch.tensor([0.0], dtype=torch.float32, device='cuda')
norm_2 = torch.zeros((1,), dtype=torch.float32, device='cuda')

if data_parallel_group is not None:
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=data_parallel_group)

# Sum across all model-parallel GPUs(tensor + pipeline).
# Add norm contribution from params with sharded main_params. These norms need to be
# accumulated across the DP group since the main parameters are sharded because
# of distributed optimizer.
if len(sharded_params_data) > 0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
sharded_norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[sharded_params_data],
False # no per-parameter norm.
)
sharded_norm_2 = sharded_norm * sharded_norm
# Sum over all DP groups.
torch.distributed.all_reduce(
sharded_norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_data_parallel_group()
)
norm_2 += sharded_norm_2

# Sum across all model-parallel GPUs (tensor + pipeline).
torch.distributed.all_reduce(
norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group()
)
# Calculate moe norm

# Add norm contribution from expert layers in MoEs.
if len(moe_params_data) > 0:
moe_norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[moe_params_data],
False # no per-parameter norm
False # no per-parameter norm.
)
moe_norm_2 = moe_norm * moe_norm
# Sum across expert tensor, model and pipeline parallel GPUs.
@@ -132,6 +167,7 @@ def calc_params_l2_norm(model):
group=mpu.get_expert_tensor_model_pipeline_parallel_group()
)
norm_2 += moe_norm_2

return norm_2.item() ** 0.5


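
The key subtlety in the updated calc_params_l2_norm is the reduction order: squared norms computed from sharded main_params are first summed over the data-parallel group, because each rank holds only its own fp32 shard, and only then is the combined squared norm summed over the model-parallel group. A condensed sketch of that accumulation with hypothetical argument names (the MoE contribution, reduced over the expert parallel group, follows the same pattern and is omitted here):

import torch

def combined_param_norm(dense_params, sharded_main_params, dp_group, mp_group):
    # Sketch of the reduction order used above, not the actual Megatron-LM implementation.
    def local_sq_norm(tensors):
        sq = torch.zeros(1, dtype=torch.float32, device='cuda')
        for t in tensors:
            sq += t.float().pow(2).sum()
        return sq

    sq = local_sq_norm(dense_params)                 # replicated across data-parallel ranks
    sharded_sq = local_sq_norm(sharded_main_params)  # only this rank's fp32 shards
    # Shards are disjoint across DP ranks, so summing squared norms over the DP group
    # recovers the squared norm of the full fp32 parameters.
    torch.distributed.all_reduce(sharded_sq, op=torch.distributed.ReduceOp.SUM, group=dp_group)
    sq += sharded_sq
    # Finally, sum over tensor- and pipeline-parallel ranks.
    torch.distributed.all_reduce(sq, op=torch.distributed.ReduceOp.SUM, group=mp_group)
    return sq.item() ** 0.5

Reducing the sharded contribution over the DP group first is what makes reusing shards equivalent to, but cheaper than, gathering full fp32 copies on every rank.
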
59 changes: 57 additions & 2 deletions tests/unit_tests/test_utils.py
@@ -1,12 +1,18 @@
import os
import time
import urllib.request as req
from types import SimpleNamespace

import mock
import numpy as np
import pytest
import torch

import megatron.core.utils as util
import megatron.training.utils as training_util
from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
from megatron.core.transformer import TransformerConfig
from tests.unit_tests.test_utilities import Utils


@@ -72,7 +78,7 @@ def test_check_param_hashes_across_dp_replicas():
# Setup.
_init_distributed(world, rank)
Utils.initialize_model_parallel()
model = torch.nn.Linear(100, 100, bias=False)
model = torch.nn.Linear(100, 100, bias=False, device='cuda')

# First check case where all replicas agree.
model.weight.data.fill_(1.0)
@@ -96,7 +102,7 @@ def test_cross_check_param_hashes_across_dp_replicas():
# Setup.
_init_distributed(world, rank)
Utils.initialize_model_parallel()
model = torch.nn.Linear(100, 100, bias=False)
model = torch.nn.Linear(100, 100, bias=False, device='cuda')

# First check case where all replicas agree.
model.weight.data.fill_(1.0)
@@ -111,6 +117,55 @@ def test_cross_check_param_hashes_across_dp_replicas():
_deinit_distributed()


@pytest.mark.parametrize("use_distributed_optimizer", [False, True])
def test_param_norm(use_distributed_optimizer: bool):
world = int(os.getenv('WORLD_SIZE', '1'))
rank = int(os.getenv('RANK', '0'))

# Setup: distributed, model, mock_args.
_init_distributed(world, rank)
Utils.initialize_model_parallel()
model = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda')
model.requires_grad_(True)
model.weight.data.fill_(1.0)
ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer)
# Use dummy TransformerConfig which doesn't trigger __post_init__ assertions.
model = DistributedDataParallel(
TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model
)
for param in model.parameters():
assert param.requires_grad
mock_args = SimpleNamespace(bf16=True)

with mock.patch('megatron.training.utils.get_args', new=lambda: mock_args):
# Make sure norm is correct when `main_param` attribute is not available.
assert training_util.calc_params_l2_norm(
model, force_create_fp32_copy=False
) == pytest.approx(100.0)
assert training_util.calc_params_l2_norm(
model, force_create_fp32_copy=True
) == pytest.approx(100.0)

# Make sure norm is correct when `main_param` attribute is available.
optimizer_config = OptimizerConfig(
bf16=True, use_distributed_optimizer=use_distributed_optimizer
)
_ = get_megatron_optimizer(optimizer_config, [model])
for param in model.parameters():
assert hasattr(param, 'main_param')
if use_distributed_optimizer:
assert getattr(param, 'main_param_sharded', False)
assert training_util.calc_params_l2_norm(
model, force_create_fp32_copy=False
) == pytest.approx(100.0)
assert training_util.calc_params_l2_norm(
model, force_create_fp32_copy=True
) == pytest.approx(100.0)

# Teardown.
_deinit_distributed()


def test_straggler_detector():
world = int(os.getenv('WORLD_SIZE', '1'))
rank = int(os.getenv('RANK', '0'))
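
For reference, the expected value of 100.0 in the new test_param_norm follows directly from the setup: a 100x100 weight filled with 1.0 has an l2 norm of sqrt(100 * 100 * 1.0^2) = 100, so the forced fp32 copy, the full fp32 main_param, and the DP-reduced sharded main_params should all agree on it.
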
