Merge branch 'xiny/fix_router_init' into 'main'
Fix initialization for gates of router and shared expert

See merge request ADLR/megatron-lm!2238
ko3n1g committed Nov 29, 2024
2 parents 67a50f2 + 1113758 · commit 8e9d4dc
Showing 3 changed files with 10 additions and 37 deletions.
11 changes: 2 additions & 9 deletions megatron/core/transformer/moe/router.py
@@ -5,11 +5,7 @@
 import torch
 
 from megatron.core import parallel_state
-from megatron.core.tensor_parallel import (
-    gather_from_sequence_parallel_region,
-    get_cuda_rng_tracker,
-    get_data_parallel_rng_tracker_name,
-)
+from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.moe.moe_utils import (
     MoEAuxLossAutoScaler,
@@ -39,14 +35,11 @@ def __init__(self, config: TransformerConfig) -> None:
         self.layer_number = None
 
         # Initialize the gate weights.
+        # TODO: Add support for GPU initialization, which requires updating the golden values.
         self.weight = torch.nn.Parameter(
             torch.empty((self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32)
         )
         if config.perform_initialization:
-            if get_cuda_rng_tracker().is_initialized():
-                with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
-                    config.init_method(self.weight)
-        else:
             config.init_method(self.weight)
         self.weight.data = self.weight.data.to(dtype=config.params_dtype)
         setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
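For context, the sketch below mirrors the router-gate initialization path that remains after this hunk: the weight is created in float32, initialized directly when perform_initialization is set (no data-parallel RNG-tracker fork), then cast to the configured params dtype. The sizes, the normal-distribution init, and the dtype are illustrative stand-ins for the TransformerConfig fields, not values taken from this commit.

import torch

num_moe_experts, hidden_size = 8, 1024                 # illustrative sizes, not from the diff
params_dtype = torch.bfloat16                          # stand-in for config.params_dtype

def init_method(tensor):                               # stand-in for config.init_method
    return torch.nn.init.normal_(tensor, mean=0.0, std=0.02)

weight = torch.nn.Parameter(
    torch.empty((num_moe_experts, hidden_size), dtype=torch.float32)
)
perform_initialization = True                          # stand-in for config.perform_initialization
if perform_initialization:
    init_method(weight)                                # plain CPU init; no get_cuda_rng_tracker() fork
weight.data = weight.data.to(dtype=params_dtype)       # cast to the training dtype after init
setattr(weight, 'sequence_parallel', True)             # flag read elsewhere by Megatron-Core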
26 changes: 4 additions & 22 deletions megatron/core/transformer/moe/shared_experts.py
@@ -17,14 +17,10 @@
     reduce_from_tensor_model_parallel_region,
     reduce_scatter_to_sequence_parallel_region,
 )
-from megatron.core.tensor_parallel.random import (
-    get_cuda_rng_tracker,
-    get_data_parallel_rng_tracker_name,
-)
 from megatron.core.transformer.mlp import MLP
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import make_sharded_tensor_for_checkpoint
+from megatron.core.utils import is_torch_min_version, make_sharded_tensor_for_checkpoint
 
 
 class SharedExpertMLP(MLP):
@@ -46,12 +42,9 @@ def __init__(self, config: TransformerConfig, spec: ModuleSpec):
 
         self.use_shared_expert_gate = spec.params.get("gate", False)
         if self.use_shared_expert_gate:
+            # TODO: Add support for GPU initialization, which requires updating the golden values.
             self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
             if config.perform_initialization:
-                if get_cuda_rng_tracker().is_initialized():
-                    with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
-                        config.init_method(self.gate_weight)
-            else:
                 config.init_method(self.gate_weight)
             self.gate_weight.data = self.gate_weight.data.to(dtype=config.params_dtype)
             setattr(self.gate_weight, 'sequence_parallel', self.config.sequence_parallel)
@@ -235,28 +228,17 @@ def get_output(self):
         return output
 
 
-TORCH_MAJOR = int(torch.__version__.split(".")[0])
-TORCH_MINOR = int(torch.__version__.split(".")[1])
-TORCH_LAST = torch.__version__.split(".")[2]
-
-
 def set_tensor_grad_fn_sequence_sr(tensor, value):
     """
     Set sequence_sr for the grad_fn of a tensor to control the backward order.
     For older PyTorch version, do nothing (backward order is not changed).
     The bigger the value is, the earlier the grad_fn is scheduled.
     """
-    if (
-        (TORCH_MAJOR > 2)
-        or (TORCH_MAJOR == 2 and TORCH_MINOR > 2)
-        or (TORCH_MAJOR == 2 and TORCH_MINOR == 2 and '+' not in TORCH_LAST)
-    ):
-        # In NVIDIA PyTorch container 24.01, the PyTorch version is 2.2.0a0+81ea7a4,
-        # which does not contian the set_sequence_nr commit.
+    if is_torch_min_version("2.2.0"):
         if tensor is not None and tensor.grad_fn is not None:
             tensor.grad_fn._set_sequence_nr(value)
     else:
         warnings.warn(
             "WARNING : PyTorch is too old to set sequence_sr and the performance may not "
-            "optimal. Please use PyTorch >= 2.2.0 for better performance."
+            "be optimal. Please use PyTorch >= 2.2.0 for better performance."
         )
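The helper is_torch_min_version lives in megatron.core.utils and its body is not part of this diff. The sketch below shows one way such a check can work, assuming PEP 440 comparison via the packaging library (an assumption of this sketch, not necessarily Megatron's implementation), and why the plain "2.2.0" threshold still excludes the 2.2.0a0+81ea7a4 build called out in the removed comment.

import torch
from packaging.version import Version   # assumed dependency for this sketch only

def is_torch_min_version_sketch(min_version: str) -> bool:
    # Rough stand-in for megatron.core.utils.is_torch_min_version: under PEP 440
    # ordering, a pre-release such as 2.2.0a0+81ea7a4 sorts below the final 2.2.0,
    # so the old "'+' not in TORCH_LAST" special case is no longer needed.
    return Version(torch.__version__) >= Version(min_version)

print(is_torch_min_version_sketch("2.2.0"))              # True on torch >= 2.2.0 final
print(Version("2.2.0a0+81ea7a4") >= Version("2.2.0"))    # False: the container build is a pre-release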
10 changes: 4 additions & 6 deletions megatron/core/transformer/torch_norm.py
@@ -2,8 +2,7 @@
 import torch
 
 from megatron.core.transformer import TransformerConfig
-
-TORCH_VERSION = torch.__version__.split('.')
+from megatron.core.utils import is_torch_min_version
 
 
 class WrappedTorchNorm:
@@ -38,10 +37,9 @@ def __new__(
         if config.normalization == "LayerNorm":
             norm_cls = torch.nn.LayerNorm
         elif config.normalization == "RMSNorm":
-            version_geq_2_4 = int(TORCH_VERSION[0]) > 2 or (
-                int(TORCH_VERSION[0]) == 2 and int(TORCH_VERSION[1]) >= 4
-            )
-            assert version_geq_2_4, 'Torch RMSNorm requires PyTorch version >= 2.4.0'
+            assert is_torch_min_version(
+                "2.4.0a0"
+            ), 'Torch RMSNorm requires PyTorch version >= 2.4.0'
 
             norm_cls = torch.nn.RMSNorm
         else:
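A note on the new "2.4.0a0" threshold: torch.nn.RMSNorm first appears in PyTorch 2.4, and a 2.4 pre-release version string (as shipped in NGC containers) would fail a plain "2.4.0" comparison under PEP 440 ordering, which is presumably why the lower bound includes the a0 suffix. The sketch below reproduces the selection logic in simplified form, using a hasattr feature check in place of the Megatron helper so it stays self-contained; it is not the module's actual code.

import torch

def pick_norm_cls(normalization: str):
    # Simplified sketch of the dispatch in WrappedTorchNorm.__new__, with hasattr standing
    # in for is_torch_min_version("2.4.0a0"); torch.nn.RMSNorm only exists on torch >= 2.4.
    if normalization == "LayerNorm":
        return torch.nn.LayerNorm
    if normalization == "RMSNorm":
        assert hasattr(torch.nn, "RMSNorm"), 'Torch RMSNorm requires PyTorch version >= 2.4.0'
        return torch.nn.RMSNorm
    raise ValueError(f"Unsupported normalization: {normalization}")

layer_norm = pick_norm_cls("LayerNorm")(1024)   # e.g. normalized_shape=1024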
