This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Changes on top of upstream to get rid of type errors (#248)
Summary:
Fixes the class of unit tests in test_base.py that fail on ROCm with the internal assertion `Cannot convert ScalarType Float8_e4m3fn to hipDataType`.

Note: We are aware of the outstanding numerical issues and are looking into them internally.

Pull Request resolved: #248

Reviewed By: vkuzo

Differential Revision: D58652172

Pulled By: drisspg

fbshipit-source-id: b62845a8eb3773bd4de5396e8c47aef94cd7e600
alugorey authored and facebook-github-bot committed Jun 20, 2024
1 parent edae9a3 commit 0bd374d
Showing 11 changed files with 79 additions and 37 deletions.
4 changes: 4 additions & 0 deletions float8_experimental/config.py
@@ -19,3 +19,7 @@
# implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
# Only dynamic scaling is supported for now.
enable_fsdp_fp8_all_gather = False

# If True, use 'fnuz' float8 types for calculations.
# Currently, ROCm only supports fnuz variants.
use_fnuz_dtype = False
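
A minimal usage sketch for the new flag (an illustration, not part of the commit): because the float8_utils.py hunk further down resolves the e4m3_dtype/e5m2_dtype aliases at import time, the flag has to be flipped before float8_utils is first imported.

import float8_experimental.config as config

# ROCm currently only supports the fnuz variants, so enable them up front,
# before importing any module that reads the aliases at import time.
config.use_fnuz_dtype = True

from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype

print(e4m3_dtype)  # torch.float8_e4m3fnuz with the flag set, torch.float8_e4m3fn otherwise
print(e5m2_dtype)  # torch.float8_e5m2fnuz with the flag set, torch.float8_e5m2 otherwise
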
12 changes: 5 additions & 7 deletions float8_experimental/float8_dynamic_linear.py
Expand Up @@ -22,7 +22,7 @@
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
from float8_experimental.float8_utils import tensor_to_scale
from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
from torch._prims_common import suggest_memory_format


@@ -46,9 +46,9 @@ def forward(
def backward(ctx, gradY):
if tensor_already_casted_to_fp8(gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
)
return fp8_tensor, None

@@ -105,10 +105,8 @@ def cast_to_float8_e4m3fn(
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
)
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


def cast_to_float8_e5m2_bw(
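
To make the pattern these hunks converge on concrete, here is a hedged sketch that uses only calls visible elsewhere in this diff (tensor_to_scale, Float8Tensor.to_float8); shapes and the high-precision dtype are arbitrary.

import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

# Dynamic cast of an activation: derive an amax-based scale, then cast to the
# config-selected e4m3 dtype instead of a hard-coded torch.float8_e4m3fn.
x = torch.randn(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(x, e4m3_dtype)
x_fp8 = Float8Tensor.to_float8(x, scale, e4m3_dtype)
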
19 changes: 12 additions & 7 deletions float8_experimental/float8_linear.py
@@ -21,7 +21,12 @@
to_fp8_no_autograd,
)

from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax
from float8_experimental.float8_utils import (
amax_history_to_scale,
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
)


def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -89,15 +94,15 @@ def backward(ctx, go):
fp8_amax_history_dL_dY,
fp8_scale_dL_dY,
scale_fn_name,
torch.float8_e5m2,
e5m2_dtype,
is_amax_initialized,
reduce_amax=True,
)

fp8_amax_dL_dY.fill_(tensor_to_amax(go))

res = to_fp8_no_autograd(
go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
)
empty_grads = None, None, None, None, None, None
return res, *empty_grads
@@ -236,14 +241,14 @@ def cast_x_to_float8(
self.fp8_amax_history_x,
self.fp8_scale_x,
scale_fn_name,
torch.float8_e4m3fn,
e4m3_dtype,
is_amax_initialized,
reduce_amax=True,
)
x_fp8 = Float8Tensor.to_float8(
x,
self.fp8_scale_x,
torch.float8_e4m3fn,
e4m3_dtype,
self.fp8_amax_x,
self.forward_config,
)
@@ -259,15 +264,15 @@ def cast_w_to_float8(
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
torch.float8_e4m3fn,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
torch.float8_e4m3fn,
e4m3_dtype,
self.fp8_amax_w,
self.forward_config,
)
12 changes: 8 additions & 4 deletions float8_experimental/float8_linear_utils.py
@@ -14,7 +14,11 @@
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

from float8_experimental.float8_utils import amax_history_to_scale_stack
from float8_experimental.float8_utils import (
amax_history_to_scale_stack,
e4m3_dtype,
e5m2_dtype,
)
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

log = logging.getLogger(__name__)
@@ -298,13 +302,13 @@ def inner_func():

# Calculate the new scales from the updated history stacks
new_x_scales = amax_history_to_scale_stack(
fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
fp8_x_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
)
new_w_scales = amax_history_to_scale_stack(
fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
fp8_w_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
)
new_dL_dY_scales = amax_history_to_scale_stack(
fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
fp8_dL_dY_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
)

# Iterate through the layers and update the scales
8 changes: 6 additions & 2 deletions float8_experimental/float8_tensor.py
@@ -9,7 +9,11 @@
import torch

import torch.distributed._functional_collectives as funcol
from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
from float8_experimental.float8_utils import (
e4m3_dtype,
tensor_to_amax,
to_fp8_saturated,
)
from torch.distributed._tensor import DTensor

aten = torch.ops.aten
@@ -125,7 +129,7 @@ def forward(
ctx,
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype=torch.float8_e4m3fn,
float8_dtype=e4m3_dtype,
amax_buffer: Optional[torch.Tensor] = None,
mm_config: Optional[ScaledMMConfig] = None,
):
11 changes: 9 additions & 2 deletions float8_experimental/float8_utils.py
@@ -6,6 +6,8 @@

from typing import Literal, Tuple

import float8_experimental.config as config

import torch
import torch.distributed as dist

@@ -16,7 +18,7 @@
# TODO: align this value with NVIDIA's assumptions (current value is a guess)
EPS = 1e-12

IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None
FP8_TYPES = {
torch.float8_e4m3fn,
torch.float8_e5m2,
@@ -25,6 +27,11 @@
}


# User defined type for using the individual F8 type based on config
e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz


@torch.no_grad()
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -148,7 +155,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor):


def fp8_tensor_statistics(
tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
tensor: torch.Tensor, float8_dtype=e4m3_dtype
) -> Tuple[int, ...]:
"""Calculate FP8 tensor stats
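
A side note on why the alias matters numerically (not part of the commit, and possibly related to the numerical issues mentioned in the summary): the fnuz variants use a shifted exponent bias and a single NaN encoding, so e4m3fnuz has a narrower range than e4m3fn. A quick check, assuming a PyTorch build that ships all four dtypes from FP8_TYPES:

import torch

print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (narrower range)
print(torch.finfo(torch.float8_e5m2).max)      # 57344.0
print(torch.finfo(torch.float8_e5m2fnuz).max)  # 57344.0
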
26 changes: 18 additions & 8 deletions test/test_base.py
@@ -30,6 +30,8 @@
)
from float8_experimental.float8_utils import (
compute_error,
e4m3_dtype,
e5m2_dtype,
fp8_tensor_statistics,
FP8_TYPES,
tensor_to_scale,
@@ -51,7 +53,7 @@ class TestFloat8Tensor(unittest.TestCase):
def test_preserves_dtype(self) -> None:
# hp means high precision, lp means low precision
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
lp_dtypes = FP8_TYPES
for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
x1_hp = torch.randn(4, 4, dtype=hp_dtype)
x1_s = tensor_to_scale(x1_hp, lp_dtype)
@@ -60,7 +62,7 @@ def test_preserves_dtype(self) -> None:
self.assertTrue(x3_hp.dtype == hp_dtype)

def test_differentiable_casts(self) -> None:
lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
lp_dtypes = (e4m3_dtype, e5m2_dtype)
for f8_dtype in lp_dtypes:
x = torch.randn(1).requires_grad_()
grad = torch.randn(1)
@@ -73,8 +75,8 @@

def test_split_cat(self):
a = torch.rand(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
scale = tensor_to_scale(a, e4m3_dtype)
fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype)

splits = torch.split(fp8_a, 16)
catted = torch.cat(splits, dim=0)
@@ -313,7 +315,7 @@ class TestScaledMM:
@pytest.mark.parametrize("use_fast_accum", [True, False])
def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
torch.manual_seed(42)
input_dtype = torch.float8_e4m3fn
input_dtype = e4m3_dtype
output_dtype = base_dtype
compare_type = torch.float32

@@ -352,7 +354,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
def test_different_configs_error(self):
x_fp32 = torch.randn(16, 16, device="cuda")
x_scale = torch.tensor(1.0, device="cuda")
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype
a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
b = Float8Tensor.to_float8(
x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
@@ -387,7 +389,7 @@ def test_merge_configs(self):


class TestNumerics:
@pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize(
"float8_dtype",
[
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_small_amax_float16(self, float8_dtype):
# If we calculate scale naively with FP8_MAX_POS / amax,
@@ -508,7 +518,7 @@ def __init__(self, dim: int):

def test_fp8_tensor_statistics(self):
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
lp_dtypes = (e4m3_dtype, e5m2_dtype)
for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
x1_hp = torch.ones(4, 4, dtype=hp_dtype)
tensor_len = x1_hp.numel()
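
The test_small_amax_float16 hunk above is cut off; as a hypothetical illustration of the failure mode its comment points at (not the library's actual mitigation), a naive FP8_MAX_POS / amax scale overflows when amax is tiny and the division is done in float16:

import torch

amax = torch.tensor(1e-6, dtype=torch.float16)
fp8_max_pos = torch.finfo(torch.float8_e4m3fn).max  # 448.0
naive_scale = fp8_max_pos / amax
print(naive_scale)  # tensor(inf, dtype=torch.float16) -- 4.48e8 exceeds float16's max of 65504
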
5 changes: 4 additions & 1 deletion test/test_compile.py
@@ -22,6 +22,7 @@
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import e4m3_dtype, IS_ROCM

from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend
@@ -116,7 +117,7 @@ def forward(self, x):
x_fp8 = Float8Tensor.to_float8(
x,
self.fp8_scale_x,
torch.float8_e4m3fn,
e4m3_dtype,
self.fp8_amax_x,
ScaledMMConfig(),
)
@@ -127,12 +128,14 @@ def forward(self, x):
return x_fp8

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(IS_ROCM, "test doesn't currently work on the ROCm stack")
def test_float8_with_graph_break_in_the_middle(self):
"""Test that having Float8Tensor object at the boundary of a subgraph"""
cnts = CompileCounterWithBackend("inductor")
mod = self.MockLinear(graph_break=True).cuda()
compiled_mod = copy.deepcopy(mod)
compiled_mod = torch.compile(compiled_mod, backend=cnts)
torch.manual_seed(0)
x = torch.randn(16, 16, device="cuda")
y_eager = mod(x)
y_compiled = compiled_mod(x)
10 changes: 5 additions & 5 deletions test/test_dtensor.py
@@ -25,7 +25,7 @@
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from float8_experimental.float8_utils import tensor_to_scale
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
@@ -64,7 +64,7 @@ def forward(self, x):

def test_scaled_mm(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype
world_size = mesh.size()

x_fp32 = torch.rand(size, size, device=device)
@@ -103,7 +103,7 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):

def test_fp8_redistribute(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype
world_size = mesh.size()

x_fp32 = torch.rand(size, size, device=device)
@@ -130,7 +130,7 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):

def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype

x_fp32 = torch.rand(size, size, device=device)
dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
@@ -144,7 +144,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):

def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = e4m3_dtype

x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)
6 changes: 6 additions & 0 deletions test/test_everything.sh
@@ -2,13 +2,19 @@

# terminate script on first error
set -e
IS_ROCM=$(rocm-smi --version || true)

pytest test/test_base.py
pytest test/test_sam.py
pytest test/test_compile.py

# These tests do not work on ROCm yet
if [ -z "$IS_ROCM" ]
then
./test/test_fsdp.sh
./test/test_fsdp_compile.sh
./test/test_dtensor.sh
pytest test/test_fsdp2/test_fsdp2_eager.py
fi

echo "all tests successful"
3 changes: 2 additions & 1 deletion test/test_sam.py
@@ -18,7 +18,7 @@
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error
from float8_experimental.float8_utils import compute_error, IS_ROCM
from transformers import SamModel

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -31,6 +31,7 @@ class TestFloat8SAMIntegrationTest:
@pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear])
@pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
@pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
def test_encoder_fw_bw(self, data_dtype, linear_type):
model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda()
# print(model)
