diff --git a/megatron/core/transformer/custom_layers/transformer_engine.py b/megatron/core/transformer/custom_layers/transformer_engine.py
index ef7e498eab..33b67231e1 100644
--- a/megatron/core/transformer/custom_layers/transformer_engine.py
+++ b/megatron/core/transformer/custom_layers/transformer_engine.py
@@ -2,6 +2,7 @@
 
 import dataclasses
 import os
+import warnings
 from importlib.metadata import version
 from typing import Callable
 
@@ -26,6 +27,8 @@
 
 
 def get_te_version():
+    """Get TE version from __version__; if not available use pip's. Use caching."""
+
     def get_te_version_str():
         if hasattr(te, '__version__'):
             return str(te.__version__)
@@ -50,6 +53,7 @@ def _get_extra_te_kwargs(config: TransformerConfig):
 
 
 def condition_init_method(config, init_method):
+    """Condition TE init_method on config.perform_initialization."""
     return init_method if config.perform_initialization else (lambda w: None)
 
 
@@ -168,6 +172,7 @@ def __init__(
         )
 
     def forward(self, x):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -287,6 +292,7 @@ def __init__(
         )
 
     def forward(self, x):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -509,6 +515,7 @@ def forward(
         attn_mask_type: AttnMaskType,
         packed_seq_params: PackedSeqParams = None,
     ):
+        """Forward."""
         packed_seq_kwargs = (
             dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
         )
@@ -647,6 +654,7 @@ def __init__(
             setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
 
     def forward(self, x, m_splits):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -827,10 +835,13 @@ def __init__(
         if _te_version >= packaging.version.Version("1.6.0.dev0"):
             extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
             extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
+        if _te_version < packaging.version.Version("1.8.0"):
+            extra_kwargs["interval"] = config.fp8_interval
+        elif config.fp8_interval != 1:
+            warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.")
 
         super().__init__(
             margin=config.fp8_margin,
-            interval=config.fp8_interval,
             fp8_format=fp8_format,
             amax_compute_algo=config.fp8_amax_compute_algo,
             amax_history_len=config.fp8_amax_history_len,
@@ -850,6 +861,7 @@ def te_checkpoint(
     context_mask,
     rotary_pos_emb,
 ):
+    """Checkpointing with Transformer-Engine."""
    from transformer_engine.pytorch.distributed import checkpoint
 
     if _te_version >= packaging.version.Version("1.5.0"):
@@ -897,7 +909,8 @@ def te_checkpoint(
 def get_cpu_offload_context(
     enabled, num_layers, model_layers, activation_offloading, weight_offloading
 ):
-    if _te_version > packaging.version.Version("1.8.0"):
+    """Get CPU offload context and sync function."""
+    if _te_version >= packaging.version.Version("1.10.0.dev0"):
         context, sync_func = _get_cpu_offload_context(
             enabled, num_layers, model_layers, activation_offloading, weight_offloading
         )
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 84626159c3..00c83ddbbb 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -158,7 +158,6 @@ class TransformerConfig(ModelParallelConfig):
     # activation recomputation
     ####################
     recompute_granularity: str = None
-    recompute_granularity: str = None
     """Determines which type of activation recompute to use.
     Megatron-core supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
     These memory intensive activations are also less compute intensive which makes activation
@@ -197,7 +196,9 @@ class TransformerConfig(ModelParallelConfig):
     """Margin for the scaling factor computation."""
 
     fp8_interval: int = 1
-    """Controls how often the scaling factor is recomputed."""
+    """DEPRECATED from TransformerEngine v1.8.0. This flag is ignored.
+    Controls how often the scaling factor is recomputed.
+    """
 
     fp8_amax_history_len: int = 1
     """The length of the amax history window used for scaling factor computation."""
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index c39c19b498..5ec39501c9 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -692,7 +692,7 @@ def _add_transformer_engine_args(parser):
                        help='Scaling margin for fp8',
                        dest='fp8_margin')
     group.add_argument('--fp8-interval', type=int, default=1,
-                       help='Scaling update interval for fp8',
+                       help='DEPRECATED. This flag is ignored. Scaling update interval for fp8',
                        dest='fp8_interval')
     group.add_argument('--fp8-amax-history-len', type=int, default=1,
                        help='Number of steps for which amax history is recorded per tensor',