[PyTorch] FP8 fixes (#380)
* Initial refactor

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Reorder methods by purpose

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Save full global state

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* More fixes to test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Aug 16, 2023
1 parent cbfb8c6 commit 2e0bfbd
Showing 7 changed files with 469 additions and 429 deletions.
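
The thread running through these diffs is that the module-level FP8 helpers is_fp8_available() and is_fp8_enabled() in transformer_engine/pytorch/fp8.py have become classmethods of FP8GlobalStateManager, so every call site now goes through one class that owns the FP8 global state (presumably why the too-many-public-methods pylint check is disabled below). A minimal sketch of that pattern, assuming hypothetical attribute names and a simplified capability check rather than the actual implementation:

# Illustrative sketch only -- not the actual transformer_engine code. It mirrors
# the pattern visible in this commit: module-level FP8 helpers become
# classmethods that read class-level ("global") state.
import torch


class FP8GlobalStateManager:
    """Owns FP8 global state as class attributes instead of module globals."""

    FP8_ENABLED = False  # hypothetical attribute; set/reset by fp8_autocast

    @classmethod
    def is_fp8_available(cls):
        """Return (available, reason); FP8 execution needs Hopper-class GPUs."""
        if not torch.cuda.is_available():
            return False, "CUDA is not available."
        major, _minor = torch.cuda.get_device_capability()
        if major < 9:
            return False, "Device compute capability 9.x required for FP8 execution."
        return True, ""

    @classmethod
    def is_fp8_enabled(cls):
        """Return whether an fp8_autocast region is currently active."""
        return cls.FP8_ENABLED
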
1 change: 1 addition & 0 deletions qa/L0_lint/pylintrc
@@ -3,6 +3,7 @@ extension-pkg-whitelist=torch,
         transformer_engine_extensions
 
 disable=too-many-locals,
+        too-many-public-methods,
         invalid-name,
         too-many-arguments,
         abstract-method,
4 changes: 2 additions & 2 deletions tests/pytorch/test_fused_attn.py
@@ -10,15 +10,15 @@
     scaled_init_method_normal,
     get_device_compute_capability,
 )
-from transformer_engine.pytorch.fp8 import is_fp8_available
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch import TransformerLayer
 from transformer_engine.pytorch.attention import DotProductAttention
 import os
 
 from pkg_resources import packaging
 from importlib.metadata import version
 from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")

4 changes: 2 additions & 2 deletions tests/pytorch/test_onnx_export.py
@@ -38,7 +38,7 @@
 import transformer_engine.pytorch.softmax as softmax_defs
 from transformer_engine.pytorch.utils import get_default_init_method
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
-from transformer_engine.pytorch.fp8 import is_fp8_available
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 
 # Global test configuration knobs.
 
@@ -66,7 +66,7 @@
 # Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
 ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
 
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 
 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
4 changes: 2 additions & 2 deletions tests/pytorch/test_sanity.py
@@ -5,7 +5,7 @@
 import torch
 import pytest
 
-from transformer_engine.pytorch.fp8 import fp8_autocast, is_fp8_available
+from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -21,7 +21,7 @@
 from transformer_engine.common import recipe
 
 # Only run FP8 tests on H100.
-fp8_available, reason_for_no_fp8 = is_fp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 
 
 def custom_amax_to_scale(
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/distributed.py
@@ -12,7 +12,7 @@
 
 from .utils import safely_set_viewless_tensor_data
 from .constants import dist_group_type
-from .fp8 import is_fp8_enabled
+from .fp8 import FP8GlobalStateManager
 
 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
     "tensor_model_parallel": False,
@@ -145,7 +145,8 @@ def activation_recompute_forward(
     """
     global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
     try:
-        _FP8_ACTIVATION_RECOMPUTE_ENABLED = activation_recompute and is_fp8_enabled()
+        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
+            activation_recompute and FP8GlobalStateManager.is_fp8_enabled())
         _FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
        yield
     finally:
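
For reference, a small usage sketch of the updated call-site pattern seen in the tests above: gate FP8 tests on FP8GlobalStateManager.is_fp8_available() and run the module under fp8_autocast. The test body itself is hypothetical and only illustrates the calling convention.

# Usage sketch mirroring the updated tests; the test body is hypothetical.
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager

# New classmethod call replaces the old module-level is_fp8_available().
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_forward_in_fp8():
    model = te.Linear(128, 128).cuda()
    inp = torch.randn(16, 128, device="cuda")
    with fp8_autocast(enabled=True):
        out = model(inp)
    assert out.shape == (16, 128)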