From 1113758d2419fcdc26d1db78cc502501953862a2 Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Fri, 29 Nov 2024 02:06:07 -0800
Subject: [PATCH] ADLR/megatron-lm!2238 - Fix initialization for gates of router and shared expert

---
 megatron/core/transformer/moe/router.py         | 11 ++------
 .../core/transformer/moe/shared_experts.py      | 26 +++----------------
 megatron/core/transformer/torch_norm.py         | 10 +++----
 3 files changed, 10 insertions(+), 37 deletions(-)

diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index a4d0301716..e03bd5c98e 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -5,11 +5,7 @@
 import torch
 
 from megatron.core import parallel_state
-from megatron.core.tensor_parallel import (
-    gather_from_sequence_parallel_region,
-    get_cuda_rng_tracker,
-    get_data_parallel_rng_tracker_name,
-)
+from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.moe.moe_utils import (
     MoEAuxLossAutoScaler,
@@ -39,14 +35,11 @@ def __init__(self, config: TransformerConfig) -> None:
         self.layer_number = None
 
         # Initialize the gate weights.
+        # TODO: Add support for GPU initialization, which requires updating the golden values.
         self.weight = torch.nn.Parameter(
             torch.empty((self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32)
         )
         if config.perform_initialization:
-            if get_cuda_rng_tracker().is_initialized():
-                with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
-                    config.init_method(self.weight)
-            else:
-                config.init_method(self.weight)
+            config.init_method(self.weight)
         self.weight.data = self.weight.data.to(dtype=config.params_dtype)
         setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
diff --git a/megatron/core/transformer/moe/shared_experts.py b/megatron/core/transformer/moe/shared_experts.py
index c2d9c188e3..1d4b2a628f 100644
--- a/megatron/core/transformer/moe/shared_experts.py
+++ b/megatron/core/transformer/moe/shared_experts.py
@@ -17,14 +17,10 @@
     reduce_from_tensor_model_parallel_region,
     reduce_scatter_to_sequence_parallel_region,
 )
-from megatron.core.tensor_parallel.random import (
-    get_cuda_rng_tracker,
-    get_data_parallel_rng_tracker_name,
-)
 from megatron.core.transformer.mlp import MLP
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import make_sharded_tensor_for_checkpoint
+from megatron.core.utils import is_torch_min_version, make_sharded_tensor_for_checkpoint
 
 
 class SharedExpertMLP(MLP):
@@ -46,12 +42,9 @@ def __init__(self, config: TransformerConfig, spec: ModuleSpec):
         self.use_shared_expert_gate = spec.params.get("gate", False)
         if self.use_shared_expert_gate:
+            # TODO: Add support for GPU initialization, which requires updating the golden values.
             self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
             if config.perform_initialization:
-                if get_cuda_rng_tracker().is_initialized():
-                    with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
-                        config.init_method(self.gate_weight)
-                else:
-                    config.init_method(self.gate_weight)
+                config.init_method(self.gate_weight)
             self.gate_weight.data = self.gate_weight.data.to(dtype=config.params_dtype)
             setattr(self.gate_weight, 'sequence_parallel', self.config.sequence_parallel)
         else:
@@ -235,28 +228,17 @@ def get_output(self):
         return output
 
 
-TORCH_MAJOR = int(torch.__version__.split(".")[0])
-TORCH_MINOR = int(torch.__version__.split(".")[1])
-TORCH_LAST = torch.__version__.split(".")[2]
-
-
 def set_tensor_grad_fn_sequence_sr(tensor, value):
     """
     Set sequence_sr for the grad_fn of a tensor to control the backward order.
     For older PyTorch version, do nothing (backward order is not changed).
     The bigger the value is, the earlier the grad_fn is scheduled.
     """
-    if (
-        (TORCH_MAJOR > 2)
-        or (TORCH_MAJOR == 2 and TORCH_MINOR > 2)
-        or (TORCH_MAJOR == 2 and TORCH_MINOR == 2 and '+' not in TORCH_LAST)
-    ):
-        # In NVIDIA PyTorch container 24.01, the PyTorch version is 2.2.0a0+81ea7a4,
-        # which does not contian the set_sequence_nr commit.
+    if is_torch_min_version("2.2.0"):
         if tensor is not None and tensor.grad_fn is not None:
             tensor.grad_fn._set_sequence_nr(value)
     else:
         warnings.warn(
             "WARNING : PyTorch is too old to set sequence_sr and the performance may not "
-            "optimal. Please use PyTorch >= 2.2.0 for better performance."
+            "be optimal. Please use PyTorch >= 2.2.0 for better performance."
         )
diff --git a/megatron/core/transformer/torch_norm.py b/megatron/core/transformer/torch_norm.py
index 7a3a7cb9b0..5fcb74da8b 100644
--- a/megatron/core/transformer/torch_norm.py
+++ b/megatron/core/transformer/torch_norm.py
@@ -2,8 +2,7 @@
 import torch
 
 from megatron.core.transformer import TransformerConfig
-
-TORCH_VERSION = torch.__version__.split('.')
+from megatron.core.utils import is_torch_min_version
 
 
 class WrappedTorchNorm:
@@ -38,10 +37,9 @@ def __new__(
         if config.normalization == "LayerNorm":
             norm_cls = torch.nn.LayerNorm
         elif config.normalization == "RMSNorm":
-            version_geq_2_4 = int(TORCH_VERSION[0]) > 2 or (
-                int(TORCH_VERSION[0]) == 2 and int(TORCH_VERSION[1]) >= 4
-            )
-            assert version_geq_2_4, 'Torch RMSNorm requires PyTorch version >= 2.4.0'
+            assert is_torch_min_version(
+                "2.4.0a0"
+            ), 'Torch RMSNorm requires PyTorch version >= 2.4.0'
             norm_cls = torch.nn.RMSNorm
         else:
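
Note on the version checks above: is_torch_min_version() compares PEP 440 versions instead of
splitting torch.__version__ on dots, so prerelease and local build tags are handled correctly.
That is presumably why the RMSNorm assert uses "2.4.0a0": an NGC container build such as
2.4.0a0+<sha> satisfies ">= 2.4.0a0" but would fail a plain ">= 2.4.0" comparison, while
2.2.0a0+81ea7a4 still fails ">= 2.2.0" as intended by the removed comment. A minimal sketch of
an equivalent helper, assuming the packaging library is available (the actual
megatron.core.utils implementation may differ in detail):

    import torch
    from packaging.version import Version

    def torch_is_at_least(min_version: str) -> bool:
        """PEP 440-aware check that the installed torch is at least min_version."""
        # Examples: Version("2.4.0a0+81ea7a4") >= Version("2.4.0a0") is True,
        # but Version("2.2.0a0+81ea7a4") >= Version("2.2.0") is False.
        return Version(torch.__version__) >= Version(min_version)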