From 0bd374debe79548e9133d0b41426c7775e5f8f83 Mon Sep 17 00:00:00 2001 From: Andres Lugo Date: Thu, 20 Jun 2024 12:33:29 -0700 Subject: [PATCH] Changes on top of upstream to get rid of type errors (#248) Summary: Fixes the class of failed unit tests on ROCm in test_base.py that fail the internal assertion `Cannot convert ScalarType Float8_e4m3fn to hipDataType.` Note: We are aware of the outstanding numerical issues and are looking into them internally. Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/248 Reviewed By: vkuzo Differential Revision: D58652172 Pulled By: drisspg fbshipit-source-id: b62845a8eb3773bd4de5396e8c47aef94cd7e600 --- float8_experimental/config.py | 4 +++ float8_experimental/float8_dynamic_linear.py | 12 ++++----- float8_experimental/float8_linear.py | 19 ++++++++------ float8_experimental/float8_linear_utils.py | 12 ++++++--- float8_experimental/float8_tensor.py | 8 ++++-- float8_experimental/float8_utils.py | 11 +++++++-- test/test_base.py | 26 ++++++++++++++------ test/test_compile.py | 5 +++- test/test_dtensor.py | 10 ++++---- test/test_everything.sh | 6 +++++ test/test_sam.py | 3 ++- 11 files changed, 79 insertions(+), 37 deletions(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 4956428c..41b278c4 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -19,3 +19,7 @@ # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2. # Only dynamic scaling is supported for now. enable_fsdp_fp8_all_gather = False + +# If True, use 'fnuz' float8 types for calculations. +# Currently, ROCm only supports fnuz variants. 
+use_fnuz_dtype = False diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 701ae8a1..0d4dbc04 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -22,7 +22,7 @@ tensor_already_casted_to_fp8, to_fp8_no_autograd, ) -from float8_experimental.float8_utils import tensor_to_scale +from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale from torch._prims_common import suggest_memory_format @@ -46,9 +46,9 @@ def forward( def backward(ctx, gradY): if tensor_already_casted_to_fp8(gradY): return gradY, None - gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2) + gradY_scale = tensor_to_scale(gradY, e5m2_dtype) fp8_tensor = to_fp8_no_autograd( - gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config + gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config ) return fp8_tensor, None @@ -105,10 +105,8 @@ def cast_to_float8_e4m3fn( ) -> Float8Tensor: if tensor_already_casted_to_fp8(inpt_tensor): return inpt_tensor - scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax) - return Float8Tensor.to_float8( - inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config - ) + scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) + return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) def cast_to_float8_e5m2_bw( diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 9e1fdf8c..35c03c0e 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -21,7 +21,12 @@ to_fp8_no_autograd, ) -from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax +from float8_experimental.float8_utils import ( + amax_history_to_scale, + e4m3_dtype, + e5m2_dtype, + tensor_to_amax, +) def _maybe_initialize_amaxes_scales_for_float8_cast( @@ -89,7 +94,7 @@ def backward(ctx, go): fp8_amax_history_dL_dY, 
fp8_scale_dL_dY, scale_fn_name, - torch.float8_e5m2, + e5m2_dtype, is_amax_initialized, reduce_amax=True, ) @@ -97,7 +102,7 @@ def backward(ctx, go): fp8_amax_dL_dY.fill_(tensor_to_amax(go)) res = to_fp8_no_autograd( - go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config + go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config ) empty_grads = None, None, None, None, None, None return res, *empty_grads @@ -236,14 +241,14 @@ def cast_x_to_float8( self.fp8_amax_history_x, self.fp8_scale_x, scale_fn_name, - torch.float8_e4m3fn, + e4m3_dtype, is_amax_initialized, reduce_amax=True, ) x_fp8 = Float8Tensor.to_float8( x, self.fp8_scale_x, - torch.float8_e4m3fn, + e4m3_dtype, self.fp8_amax_x, self.forward_config, ) @@ -259,7 +264,7 @@ def cast_w_to_float8( self.fp8_amax_history_w, self.fp8_scale_w, scale_fn_name, - torch.float8_e4m3fn, + e4m3_dtype, is_amax_initialized, reduce_amax=False, ) @@ -267,7 +272,7 @@ def cast_w_to_float8( w_fp8 = Float8Tensor.to_float8( w, self.fp8_scale_w, - torch.float8_e4m3fn, + e4m3_dtype, self.fp8_amax_w, self.forward_config, ) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 4d850459..92392006 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -14,7 +14,11 @@ from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_utils import amax_history_to_scale_stack +from float8_experimental.float8_utils import ( + amax_history_to_scale_stack, + e4m3_dtype, + e5m2_dtype, +) from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor log = logging.getLogger(__name__) @@ -298,13 +302,13 @@ def inner_func(): # Calculate the new scales from the updated history stacks new_x_scales = amax_history_to_scale_stack( - fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe + 
fp8_x_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe ) new_w_scales = amax_history_to_scale_stack( - fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe + fp8_w_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe ) new_dL_dY_scales = amax_history_to_scale_stack( - fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe + fp8_dL_dY_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe ) # Iterate through the layers and update the scales diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 2535b69c..5c8e9a8c 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -9,7 +9,11 @@ import torch import torch.distributed._functional_collectives as funcol -from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated +from float8_experimental.float8_utils import ( + e4m3_dtype, + tensor_to_amax, + to_fp8_saturated, +) from torch.distributed._tensor import DTensor aten = torch.ops.aten @@ -125,7 +129,7 @@ def forward( ctx, tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype=torch.float8_e4m3fn, + float8_dtype=e4m3_dtype, amax_buffer: Optional[torch.Tensor] = None, mm_config: Optional[ScaledMMConfig] = None, ): diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 8d898ea9..f9ae70a0 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -6,6 +6,8 @@ from typing import Literal, Tuple +import float8_experimental.config as config + import torch import torch.distributed as dist @@ -16,7 +18,7 @@ # TODO: align this value with NVIDIA's assumptions (current value is a guess) EPS = 1e-12 -IS_AMD = torch.cuda.is_available() and torch.version.hip is not None +IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None FP8_TYPES = { torch.float8_e4m3fn, torch.float8_e5m2, @@ -25,6 +27,11 @@ } +# User defined type for using the individual F8 
type based on config +e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz +e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz + + @torch.no_grad() def amax_to_scale( amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype @@ -148,7 +155,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor): def fp8_tensor_statistics( - tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn + tensor: torch.Tensor, float8_dtype=e4m3_dtype ) -> Tuple[int, ...]: """Calculate FP8 tensor stats diff --git a/test/test_base.py b/test/test_base.py index 371e044f..da9da87f 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -30,6 +30,8 @@ ) from float8_experimental.float8_utils import ( compute_error, + e4m3_dtype, + e5m2_dtype, fp8_tensor_statistics, FP8_TYPES, tensor_to_scale, @@ -51,7 +53,7 @@ class TestFloat8Tensor(unittest.TestCase): def test_preserves_dtype(self) -> None: # hp means high precision, lp means low precision hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) - lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2) + lp_dtypes = FP8_TYPES for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes): x1_hp = torch.randn(4, 4, dtype=hp_dtype) x1_s = tensor_to_scale(x1_hp, lp_dtype) @@ -60,7 +62,7 @@ def test_preserves_dtype(self) -> None: self.assertTrue(x3_hp.dtype == hp_dtype) def test_differentiable_casts(self) -> None: - lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2) + lp_dtypes = (e4m3_dtype, e5m2_dtype) for f8_dtype in lp_dtypes: x = torch.randn(1).requires_grad_() grad = torch.randn(1) @@ -73,8 +75,8 @@ def test_differentiable_casts(self) -> None: def test_split_cat(self): a = torch.rand(16, 16, dtype=torch.bfloat16) - scale = tensor_to_scale(a, torch.float8_e4m3fn) - fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn) + scale = tensor_to_scale(a, e4m3_dtype) + fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype) splits = torch.split(fp8_a, 16) 
catted = torch.cat(splits, dim=0) @@ -313,7 +315,7 @@ class TestScaledMM: @pytest.mark.parametrize("use_fast_accum", [True, False]) def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): torch.manual_seed(42) - input_dtype = torch.float8_e4m3fn + input_dtype = e4m3_dtype output_dtype = base_dtype compare_type = torch.float32 @@ -352,7 +354,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): def test_different_configs_error(self): x_fp32 = torch.randn(16, 16, device="cuda") x_scale = torch.tensor(1.0, device="cuda") - fp8_dtype = torch.float8_e4m3fn + fp8_dtype = e4m3_dtype a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) b = Float8Tensor.to_float8( x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True) @@ -387,7 +389,15 @@ def test_merge_configs(self): class TestNumerics: - @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + @pytest.mark.parametrize( + "float8_dtype", + [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ], + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_small_amax_float16(self, float8_dtype): # If we calculate scale naively with FP8_MAX_POS / amax, @@ -508,7 +518,7 @@ def __init__(self, dim: int): def test_fp8_tensor_statistics(self): hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) - lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2) + lp_dtypes = (e4m3_dtype, e5m2_dtype) for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes): x1_hp = torch.ones(4, 4, dtype=hp_dtype) tensor_len = x1_hp.numel() diff --git a/test/test_compile.py b/test/test_compile.py index 9cc64d32..34e97538 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -22,6 +22,7 @@ sync_float8_amax_and_scale_history, ) from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_utils import e4m3_dtype, IS_ROCM from torch._dynamo.test_case import TestCase 
as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend @@ -116,7 +117,7 @@ def forward(self, x): x_fp8 = Float8Tensor.to_float8( x, self.fp8_scale_x, - torch.float8_e4m3fn, + e4m3_dtype, self.fp8_amax_x, ScaledMMConfig(), ) @@ -127,12 +128,14 @@ def forward(self, x): return x_fp8 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(IS_ROCM, "test doesn't currently work on the ROCm stack") def test_float8_with_graph_break_in_the_middle(self): """Test that having Float8Tensor object at the boundary of a subgraph""" cnts = CompileCounterWithBackend("inductor") mod = self.MockLinear(graph_break=True).cuda() compiled_mod = copy.deepcopy(mod) compiled_mod = torch.compile(compiled_mod, backend=cnts) + torch.manual_seed(0) x = torch.randn(16, 16, device="cuda") y_eager = mod(x) y_compiled = compiled_mod(x) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index e3196085..354f8316 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -25,7 +25,7 @@ Float8RowwiseParallel, PrepareFloat8ModuleInput, ) -from float8_experimental.float8_utils import tensor_to_scale +from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.tensor.parallel import parallelize_module @@ -64,7 +64,7 @@ def forward(self, x): def test_scaled_mm(mesh: DeviceMesh, size=16): device = mesh.device_type - fp8_dtype = torch.float8_e4m3fn + fp8_dtype = e4m3_dtype world_size = mesh.size() x_fp32 = torch.rand(size, size, device=device) @@ -103,7 +103,7 @@ def test_scaled_mm(mesh: DeviceMesh, size=16): def test_fp8_redistribute(mesh: DeviceMesh, size=16): device = mesh.device_type - fp8_dtype = torch.float8_e4m3fn + fp8_dtype = e4m3_dtype world_size = mesh.size() x_fp32 = torch.rand(size, size, device=device) @@ -130,7 +130,7 @@ def 
test_fp8_redistribute(mesh: DeviceMesh, size=16): def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16): device = mesh.device_type - fp8_dtype = torch.float8_e4m3fn + fp8_dtype = e4m3_dtype x_fp32 = torch.rand(size, size, device=device) dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) @@ -144,7 +144,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16): def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): device = mesh.device_type - fp8_dtype = torch.float8_e4m3fn + fp8_dtype = e4m3_dtype x_fp32 = torch.rand(size, size, device=device, requires_grad=True) local_weight = torch.rand(2 * size, size, device=device, requires_grad=True) diff --git a/test/test_everything.sh b/test/test_everything.sh index c7817f9c..b9893933 100755 --- a/test/test_everything.sh +++ b/test/test_everything.sh @@ -2,13 +2,19 @@ # terminate script on first error set -e +IS_ROCM=$(rocm-smi --version || true) pytest test/test_base.py pytest test/test_sam.py pytest test/test_compile.py + +# These tests do not work on ROCm yet +if [ -z "$IS_ROCM" ] +then ./test/test_fsdp.sh ./test/test_fsdp_compile.sh ./test/test_dtensor.sh pytest test/test_fsdp2/test_fsdp2_eager.py +fi echo "all tests successful" diff --git a/test/test_sam.py b/test/test_sam.py index d8eb7195..9341241e 100644 --- a/test/test_sam.py +++ b/test/test_sam.py @@ -18,7 +18,7 @@ swap_linear_with_float8_linear, sync_float8_amax_and_scale_history, ) -from float8_experimental.float8_utils import compute_error +from float8_experimental.float8_utils import compute_error, IS_ROCM from transformers import SamModel is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) @@ -31,6 +31,7 @@ class TestFloat8SAMIntegrationTest: @pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear]) @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") + @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work 
on the ROCm stack") def test_encoder_fw_bw(self, data_dtype, linear_type): model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda() # print(model)