Skip to content

Commit

Permalink
use packaging.version
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Yao <[email protected]>
  • Loading branch information
yaox12 committed Nov 5, 2024
1 parent 64a506c commit 30f8cae
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 25 deletions.
2 changes: 1 addition & 1 deletion tests/pytorch/test_fused_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.jit import gpu_autocast_ctx
from transformer_engine.pytorch.utils import gpu_autocast_ctx

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
Expand Down
6 changes: 2 additions & 4 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules

from .utils import safely_set_viewless_tensor_data
from .utils import safely_set_viewless_tensor_data, is_torch_min_version
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager
from .float8_tensor import Float8Tensor
Expand Down Expand Up @@ -252,9 +252,7 @@ def _get_active_autocast_contexts():
"""
autocast_cached = torch.is_autocast_cache_enabled()

TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
if is_torch_min_version("2.4.0a0"):
gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
gpu_autocast_ctx = torch.amp.autocast(
Expand Down
22 changes: 8 additions & 14 deletions transformer_engine/pytorch/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,43 @@
"""NVFuser functions and JIT utilities"""
import os
from typing import Callable, Optional, Tuple
from functools import partial

import torch

from .utils import is_torch_min_version, gpu_autocast_ctx

# pylint: disable=unnecessary-lambda-assignment

jit_fuser = torch.jit.script
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
if is_torch_min_version("2a0") and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile

# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
if is_torch_min_version("2.2a0") and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = torch.compile

# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
if is_torch_min_version("2a0"):
import torch._dynamo

if torch.__version__ >= "2.1":
if is_torch_min_version("2.1a0"):
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
f, recursive=recursive
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable

if torch.__version__ >= "2.4":
gpu_autocast_ctx = partial(torch.amp.autocast, device_type="cuda")
else:
gpu_autocast_ctx = torch.cuda.amp.autocast


def set_jit_fusion_options() -> None:
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
if is_torch_min_version("2.2.0a0"):
pass
elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
elif is_torch_min_version("1.10.0a0"):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
Expand Down
5 changes: 2 additions & 3 deletions transformer_engine/pytorch/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
cast_if_needed,
get_default_init_method,
torch_get_autocast_gpu_dtype,
is_torch_min_version,
)
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
Expand Down Expand Up @@ -431,9 +432,7 @@ def __init__(
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
use_nvfuser = is_torch_min_version("1.10.0a0") and not is_torch_min_version("2.2.0a0")
self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

if self.bias_dropout_fusion:
Expand Down
20 changes: 17 additions & 3 deletions transformer_engine/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,21 @@
from __future__ import annotations
import functools
import math
from packaging.version import Version as PkgVersion
from typing import Any, Callable, Optional, Tuple

import torch
import transformer_engine.pytorch.cpp_extensions as ext

_torch_version = PkgVersion(torch.__version__)


def is_torch_min_version(version, check_equality=True):
"""Check if minimum version of `torch` is installed."""
if check_equality:
return _torch_version >= PkgVersion(version)
return _torch_version > PkgVersion(version)


def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
"""Check if any of the given tensors require gradient."""
Expand Down Expand Up @@ -309,8 +319,12 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:

def torch_get_autocast_gpu_dtype() -> torch.dtype:
"""Get PyTorch autocast GPU dtype."""
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
if is_torch_min_version("2.4.0a0"):
return torch.get_autocast_dtype("cuda")
return torch.get_autocast_gpu_dtype()


if is_torch_min_version("2.4.0a0"):
gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
else:
gpu_autocast_ctx = torch.cuda.amp.autocast

0 comments on commit 30f8cae

Please sign in to comment.