From 98b43c91d004dec254f1610d9cffae8aff8550f3 Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Wed, 4 Sep 2024 01:05:24 -0700
Subject: [PATCH] ADLR/megatron-lm!1935 - Fix TE versions

---
 .../custom_layers/transformer_engine.py          | 17 +++++++++++++++--
 megatron/core/transformer/transformer_config.py  |  5 +++--
 megatron/training/arguments.py                   |  2 +-
 3 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/megatron/core/transformer/custom_layers/transformer_engine.py b/megatron/core/transformer/custom_layers/transformer_engine.py
index 4d73995bbd..6a265c5b3c 100644
--- a/megatron/core/transformer/custom_layers/transformer_engine.py
+++ b/megatron/core/transformer/custom_layers/transformer_engine.py
@@ -2,6 +2,7 @@
 
 import dataclasses
 import os
+import warnings
 from importlib.metadata import version
 from typing import Callable
 
@@ -26,6 +27,8 @@
 
 
 def get_te_version():
+    """Get TE version from __version__; if not available use pip's. Use caching."""
+
     def get_te_version_str():
         if hasattr(te, '__version__'):
             return str(te.__version__)
@@ -50,6 +53,7 @@ def _get_extra_te_kwargs(config: TransformerConfig):
 
 
 def condition_init_method(config, init_method):
+    """Condition TE init_method on config.perform_initialization."""
     return init_method if config.perform_initialization else (lambda w: None)
 
 
@@ -168,6 +172,7 @@ def __init__(
         )
 
     def forward(self, x):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -287,6 +292,7 @@ def __init__(
         )
 
     def forward(self, x):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -508,6 +514,7 @@ def forward(
         attn_mask_type: AttnMaskType,
         packed_seq_params: PackedSeqParams = None,
     ):
+        """Forward."""
        packed_seq_kwargs = (
             dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
         )
@@ -644,6 +651,7 @@ def __init__(
             setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
 
     def forward(self, x, m_splits):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -824,10 +832,13 @@ def __init__(
         if _te_version >= packaging.version.Version("1.6.0.dev0"):
             extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
             extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
+        if _te_version < packaging.version.Version("1.8.0"):
+            extra_kwargs["interval"] = config.fp8_interval
+        elif config.fp8_interval != 1:
+            warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.")
 
         super().__init__(
             margin=config.fp8_margin,
-            interval=config.fp8_interval,
             fp8_format=fp8_format,
             amax_compute_algo=config.fp8_amax_compute_algo,
             amax_history_len=config.fp8_amax_history_len,
@@ -847,6 +858,7 @@ def te_checkpoint(
     context_mask,
     rotary_pos_emb,
 ):
+    """Checkpointing with Transformer-Engine."""
     from transformer_engine.pytorch.distributed import checkpoint
 
     if _te_version >= packaging.version.Version("1.5.0"):
@@ -894,7 +906,8 @@
     def get_cpu_offload_context(
         enabled, num_layers, model_layers, activation_offloading, weight_offloading
     ):
-        if _te_version > packaging.version.Version("1.8.0"):
+        """Get CPU offload context and sync function."""
+        if _te_version >= packaging.version.Version("1.10.0.dev0"):
             context, sync_func = _get_cpu_offload_context(
                 enabled, num_layers, model_layers, activation_offloading, weight_offloading
             )
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 1d1b55592a..4bf393cdf6 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -158,7 +158,6 @@ class TransformerConfig(ModelParallelConfig):
     # activation recomputation
     ####################
     recompute_granularity: str = None
-    recompute_granularity: str = None
     """Determines which type of activation recompute to use. Megatron-core supports 'selective'
     activation checkpointing where only the memory intensive part of attention is checkpointed.
     These memory intensive activations are also less compute intensive which makes activation
@@ -197,7 +196,9 @@
     """Margin for the scaling factor computation."""
 
     fp8_interval: int = 1
-    """Controls how often the scaling factor is recomputed."""
+    """DEPRECATED from TransformerEngine v1.8.0. This flag is ignored.
+    Controls how often the scaling factor is recomputed.
+    """
 
     fp8_amax_history_len: int = 1
     """The length of the amax history window used for scaling factor computation."""
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 46f573a2b2..d7764bd907 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -679,7 +679,7 @@ def _add_transformer_engine_args(parser):
                        help='Scaling margin for fp8',
                        dest='fp8_margin')
     group.add_argument('--fp8-interval', type=int, default=1,
-                       help='Scaling update interval for fp8',
+                       help='DEPRECATED. This flag is ignored. Scaling update interval for fp8',
                        dest='fp8_interval')
     group.add_argument('--fp8-amax-history-len', type=int, default=1,
                        help='Number of steps for which amax history is recorded per tensor',