
Commit

Merge branch 'xiny/fix_te_versions' into 'main'
Fix TE versions

See merge request ADLR/megatron-lm!1935
ko3n1g committed Sep 4, 2024
2 parents 27289dc + 98b43c9 commit c2dc781
Showing 3 changed files with 19 additions and 5 deletions.
17 changes: 15 additions & 2 deletions megatron/core/transformer/custom_layers/transformer_engine.py
@@ -2,6 +2,7 @@
 
 import dataclasses
 import os
+import warnings
 from importlib.metadata import version
 from typing import Callable
 
@@ -26,6 +27,8 @@
 
 
 def get_te_version():
+    """Get TE version from __version__; if not available use pip's. Use caching."""
+
     def get_te_version_str():
         if hasattr(te, '__version__'):
             return str(te.__version__)
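
The docstring added above describes caching the TE version lookup, falling back to pip metadata when te.__version__ is absent. A minimal, standalone sketch of that pattern (cached_te_version is a hypothetical helper for illustration, not the Megatron-LM implementation):

import functools
from importlib.metadata import version

import packaging.version
import transformer_engine as te


@functools.lru_cache(maxsize=None)
def cached_te_version() -> packaging.version.Version:
    """Parse the installed Transformer Engine version once and cache it."""
    if hasattr(te, '__version__'):
        version_str = str(te.__version__)
    else:
        # Fall back to the pip-installed package metadata.
        version_str = version("transformer-engine")
    return packaging.version.Version(version_str)
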
@@ -50,6 +53,7 @@ def _get_extra_te_kwargs(config: TransformerConfig):
 
 
 def condition_init_method(config, init_method):
+    """Condition TE init_method on config.perform_initialization."""
     return init_method if config.perform_initialization else (lambda w: None)
 
 
@@ -168,6 +172,7 @@ def __init__(
         )
 
     def forward(self, x):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -287,6 +292,7 @@ def __init__(
         )
 
     def forward(self, x):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -509,6 +515,7 @@ def forward(
         attn_mask_type: AttnMaskType,
         packed_seq_params: PackedSeqParams = None,
     ):
+        """Forward."""
         packed_seq_kwargs = (
             dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
         )
@@ -647,6 +654,7 @@ def __init__(
             setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
 
     def forward(self, x, m_splits):
+        """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
         )
@@ -827,10 +835,13 @@ def __init__(
         if _te_version >= packaging.version.Version("1.6.0.dev0"):
             extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
             extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
+        if _te_version < packaging.version.Version("1.8.0"):
+            extra_kwargs["interval"] = config.fp8_interval
+        elif config.fp8_interval != 1:
+            warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.")
 
         super().__init__(
             margin=config.fp8_margin,
-            interval=config.fp8_interval,
             fp8_format=fp8_format,
             amax_compute_algo=config.fp8_amax_compute_algo,
             amax_history_len=config.fp8_amax_history_len,
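
The hunk above passes interval to TE's DelayedScaling recipe only when the installed version still accepts it, and warns if a non-default fp8_interval would be silently ignored. A self-contained sketch of that version-gated kwargs pattern (_FP8Config and build_fp8_kwargs below are illustrative stand-ins, not Megatron-LM code):

import warnings
from dataclasses import dataclass

import packaging.version


@dataclass
class _FP8Config:  # hypothetical stand-in for TransformerConfig
    fp8_interval: int = 1


def build_fp8_kwargs(config: _FP8Config, te_version: packaging.version.Version) -> dict:
    """Only include `interval` for TE versions that still accept it."""
    kwargs = {}
    if te_version < packaging.version.Version("1.8.0"):
        kwargs["interval"] = config.fp8_interval
    elif config.fp8_interval != 1:
        warnings.warn(
            "fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0."
        )
    return kwargs


# On TE >= 1.8.0 with a non-default value: no `interval` key, plus a warning.
print(build_fp8_kwargs(_FP8Config(fp8_interval=4), packaging.version.Version("1.9.0")))
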
@@ -850,6 +861,7 @@ def te_checkpoint(
     context_mask,
     rotary_pos_emb,
 ):
+    """Checkpointing with Transformer-Engine."""
     from transformer_engine.pytorch.distributed import checkpoint
 
     if _te_version >= packaging.version.Version("1.5.0"):
@@ -897,7 +909,8 @@ def te_checkpoint(
 def get_cpu_offload_context(
     enabled, num_layers, model_layers, activation_offloading, weight_offloading
 ):
-    if _te_version > packaging.version.Version("1.8.0"):
+    """Get CPU offload context and sync function."""
+    if _te_version >= packaging.version.Version("1.10.0.dev0"):
         context, sync_func = _get_cpu_offload_context(
             enabled, num_layers, model_layers, activation_offloading, weight_offloading
         )
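
The get_cpu_offload_context change above replaces the `> 1.8.0` check with `>= 1.10.0.dev0`. Because packaging orders dev pre-releases just before their final release, this admits 1.10.0 pre-release builds while still excluding everything up to 1.9.x. A small illustration (not part of the patch):

import packaging.version

threshold = packaging.version.Version("1.10.0.dev0")
for candidate in ("1.8.0", "1.9.0", "1.10.0.dev0", "1.10.0", "1.11.0"):
    # Prints: False, False, True, True, True
    print(candidate, packaging.version.Version(candidate) >= threshold)
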
5 changes: 3 additions & 2 deletions megatron/core/transformer/transformer_config.py
@@ -158,7 +158,6 @@ class TransformerConfig(ModelParallelConfig):
     # activation recomputation
     ####################
-    recompute_granularity: str = None
     recompute_granularity: str = None
     """Determines which type of activation recompute to use. Megatron-core supports 'selective'
     activation checkpointing where only the memory intensive part of attention is checkpointed.
     These memory intensive activations are also less compute intensive which makes activation
@@ -197,7 +196,9 @@ class TransformerConfig(ModelParallelConfig):
     """Margin for the scaling factor computation."""
 
     fp8_interval: int = 1
-    """Controls how often the scaling factor is recomputed."""
+    """DEPRECATED from TransformerEngine v1.8.0. This flag is ignored.
+    Controls how often the scaling factor is recomputed.
+    """
 
     fp8_amax_history_len: int = 1
     """The length of the amax history window used for scaling factor computation."""
2 changes: 1 addition & 1 deletion megatron/training/arguments.py
@@ -692,7 +692,7 @@ def _add_transformer_engine_args(parser):
                        help='Scaling margin for fp8',
                        dest='fp8_margin')
     group.add_argument('--fp8-interval', type=int, default=1,
-                       help='Scaling update interval for fp8',
+                       help='DEPRECATED. This flag is ignored. Scaling update interval for fp8',
                        dest='fp8_interval')
     group.add_argument('--fp8-amax-history-len', type=int, default=1,
                        help='Number of steps for which amax history is recorded per tensor',
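
The arguments.py change only rewords the help text, so existing launch scripts that pass --fp8-interval keep parsing without errors. A hedged sketch of how such a deprecated flag could additionally warn at parse time (hypothetical snippet, not part of this commit):

import argparse
import warnings

parser = argparse.ArgumentParser()
parser.add_argument('--fp8-interval', type=int, default=1,
                    help='DEPRECATED. This flag is ignored.',
                    dest='fp8_interval')

args = parser.parse_args(['--fp8-interval', '4'])
if args.fp8_interval != 1:
    # Accept the value for backward compatibility, but tell the user it does nothing.
    warnings.warn('--fp8-interval is deprecated and ignored.')
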
