diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 4956428..41b278c 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -19,3 +19,7 @@
 # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
 # Only dynamic scaling is supported for now.
 enable_fsdp_fp8_all_gather = False
+
+# If True, use 'fnuz' float8 types for calculations.
+# Currently, ROCm only supports fnuz variants.
+use_fnuz_dtype = False
diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 701ae8a..5d1aa44 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -22,7 +22,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype, e5m2_dtype
 
 from torch._prims_common import suggest_memory_format
 
@@ -46,9 +46,9 @@ def forward(
     def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
-        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
+            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
         )
         return fp8_tensor, None
 
@@ -105,9 +105,9 @@ def cast_to_float8_e4m3fn(
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
+    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
     return Float8Tensor.to_float8(
-        inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
+        inpt_tensor, scale, e4m3_dtype, mm_config=mm_config
     )
 
 
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 9e1fdf8..a413925 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -21,7 +21,7 @@
     to_fp8_no_autograd,
 )
 
-from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax
+from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax, e4m3_dtype, e5m2_dtype
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -89,7 +89,7 @@ def backward(ctx, go):
             fp8_amax_history_dL_dY,
             fp8_scale_dL_dY,
             scale_fn_name,
-            torch.float8_e5m2,
+            e5m2_dtype,
             is_amax_initialized,
             reduce_amax=True,
         )
@@ -97,7 +97,7 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
 
         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
+            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -236,14 +236,14 @@ def cast_x_to_float8(
             self.fp8_amax_history_x,
             self.fp8_scale_x,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             is_amax_initialized,
             reduce_amax=True,
         )
         x_fp8 = Float8Tensor.to_float8(
             x,
             self.fp8_scale_x,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_x,
             self.forward_config,
         )
@@ -259,7 +259,7 @@ def cast_w_to_float8(
             self.fp8_amax_history_w,
             self.fp8_scale_w,
             scale_fn_name,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             is_amax_initialized,
             reduce_amax=False,
         )
@@ -267,7 +267,7 @@ def cast_w_to_float8(
         w_fp8 = Float8Tensor.to_float8(
             w,
             self.fp8_scale_w,
-            torch.float8_e4m3fn,
+            e4m3_dtype,
             self.fp8_amax_w,
             self.forward_config,
         )
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 4d85045..ba2b597 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -14,7 +14,7 @@
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 
-from float8_experimental.float8_utils import amax_history_to_scale_stack
+from float8_experimental.float8_utils import amax_history_to_scale_stack, e4m3_dtype, e5m2_dtype
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
 
 log = logging.getLogger(__name__)
@@ -298,13 +298,13 @@ def inner_func():
 
         # Calculate the new scales from the updated history stacks
         new_x_scales = amax_history_to_scale_stack(
-            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+            fp8_x_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
         )
         new_w_scales = amax_history_to_scale_stack(
-            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+            fp8_w_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
         )
         new_dL_dY_scales = amax_history_to_scale_stack(
-            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
+            fp8_dL_dY_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
         )
 
         # Iterate through the layers and update the scales
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 2535b69..43c576a 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -9,7 +9,7 @@
 import torch
 
 import torch.distributed._functional_collectives as funcol
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated, e4m3_dtype
 from torch.distributed._tensor import DTensor
 
 aten = torch.ops.aten
@@ -125,7 +125,7 @@ def forward(
         ctx,
         tensor: torch.Tensor,
         scale: torch.Tensor,
-        float8_dtype=torch.float8_e4m3fn,
+        float8_dtype=e4m3_dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         mm_config: Optional[ScaledMMConfig] = None,
     ):
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index 8d898ea..2200f67 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -9,6 +9,8 @@
 import torch
 import torch.distributed as dist
 
+import float8_experimental.config as config
+
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
 
@@ -16,7 +18,7 @@
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
-IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None
 FP8_TYPES = {
     torch.float8_e4m3fn,
     torch.float8_e5m2,
@@ -25,6 +27,11 @@
 }
 
 
+# User defined type for using the individual F8 type based on config
+e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
+e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
+
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -148,7 +155,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor):
 
 
 def fp8_tensor_statistics(
-    tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
+    tensor: torch.Tensor, float8_dtype=e4m3_dtype
 ) -> Tuple[int, ...]:
     """Calculate FP8 tensor stats
 
diff --git a/test/test_base.py b/test/test_base.py
index 6e7a34c..c015e1d 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -34,6 +34,8 @@
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
 )
 
 random.seed(0)
@@ -52,7 +54,7 @@ class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = FP8_TYPES
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
@@ -61,7 +63,7 @@ def test_preserves_dtype(self) -> None:
             self.assertTrue(x3_hp.dtype == hp_dtype)
 
     def test_differentiable_casts(self) -> None:
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = (e4m3_dtype, e5m2_dtype)
         for f8_dtype in lp_dtypes:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)
@@ -74,8 +76,8 @@ def test_differentiable_casts(self) -> None:
 
     def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
-        scale = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)
+        scale = tensor_to_scale(a, e4m3_dtype)
+        fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype)
 
         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)
@@ -314,7 +316,7 @@ class TestScaledMM:
     @pytest.mark.parametrize("use_fast_accum", [True, False])
     def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
-        input_dtype = torch.float8_e4m3fn
+        input_dtype = e4m3_dtype
         output_dtype = base_dtype
         compare_type = torch.float32
 
@@ -360,7 +362,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
     def test_different_configs_error(self):
        x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
-        fp8_dtype = torch.float8_e4m3fn
+        fp8_dtype = e4m3_dtype
         a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
         b = Float8Tensor.to_float8(
             x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
@@ -395,7 +397,10 @@ def test_merge_configs(self):
 
 
 class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn,
+                                              torch.float8_e5m2,
+                                              torch.float8_e4m3fnuz,
+                                              torch.float8_e5m2fnuz])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,
@@ -516,7 +521,7 @@ def __init__(self, dim: int):
 
     def test_fp8_tensor_statistics(self):
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        lp_dtypes = (e4m3_dtype, e5m2_dtype)
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.ones(4, 4, dtype=hp_dtype)
             tensor_len = x1_hp.numel()
diff --git a/test/test_compile.py b/test/test_compile.py
index 9cc64d3..34e9753 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -22,6 +22,7 @@
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_utils import e4m3_dtype, IS_ROCM
 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend
 
@@ -116,7 +117,7 @@ def forward(self, x):
             x_fp8 = Float8Tensor.to_float8(
                 x,
                 self.fp8_scale_x,
-                torch.float8_e4m3fn,
+                e4m3_dtype,
                 self.fp8_amax_x,
                 ScaledMMConfig(),
             )
@@ -127,12 +128,14 @@ def forward(self, x):
             return x_fp8
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(IS_ROCM, "test doesn't currently work on the ROCm stack")
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=True).cuda()
         compiled_mod = copy.deepcopy(mod)
         compiled_mod = torch.compile(compiled_mod, backend=cnts)
+        torch.manual_seed(0)
         x = torch.randn(16, 16, device="cuda")
         y_eager = mod(x)
         y_compiled = compiled_mod(x)
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index e319608..5d138ce 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -25,7 +25,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module
@@ -64,7 +64,7 @@ def forward(self, x):
 
 def test_scaled_mm(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
     world_size = mesh.size()
 
     x_fp32 = torch.rand(size, size, device=device)
@@ -103,7 +103,7 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):
 
 def test_fp8_redistribute(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
     world_size = mesh.size()
 
     x_fp32 = torch.rand(size, size, device=device)
@@ -130,7 +130,7 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):
 
 def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
 
     x_fp32 = torch.rand(size, size, device=device)
     dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
@@ -144,7 +144,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
 
 def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     device = mesh.device_type
-    fp8_dtype = torch.float8_e4m3fn
+    fp8_dtype = e4m3_dtype
 
     x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
     local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)
diff --git a/test/test_everything.sh b/test/test_everything.sh
index c7817f9..3933da2 100755
--- a/test/test_everything.sh
+++ b/test/test_everything.sh
@@ -2,13 +2,19 @@
 
 # terminate script on first error
 set -e
+IS_ROCM=$(rocm-smi --version)
 
 pytest test/test_base.py
 pytest test/test_sam.py
 pytest test/test_compile.py
+
+# These tests do not work on ROCm yet
+if [ -z "$IS_ROCM" ]
+then
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_dtensor.sh
 pytest test/test_fsdp2/test_fsdp2_eager.py
+fi
 
 echo "all tests successful"
diff --git a/test/test_sam.py b/test/test_sam.py
index d8eb719..9341241 100644
--- a/test/test_sam.py
+++ b/test/test_sam.py
@@ -18,7 +18,7 @@
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_utils import compute_error
+from float8_experimental.float8_utils import compute_error, IS_ROCM
 from transformers import SamModel
 
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -31,6 +31,7 @@ class TestFloat8SAMIntegrationTest:
     @pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16])
     @pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear])
     @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
     def test_encoder_fw_bw(self, data_dtype, linear_type):
         model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda()
         # print(model)
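
Usage sketch (illustrative, not part of the patch): how the new use_fnuz_dtype flag in config.py interacts with the e4m3_dtype / e5m2_dtype aliases added to float8_utils.py. The aliases are evaluated once, when float8_utils is first imported, so in real code the flag has to be set before that import (or changed in config.py itself). The importlib.reload below exists only to show both settings in one self-contained script, and it assumes the patched package is installed.

import importlib

import torch

import float8_experimental.config as config
import float8_experimental.float8_utils as float8_utils

# Default config: the standard OCP float8 types are used.
assert float8_utils.e4m3_dtype is torch.float8_e4m3fn
assert float8_utils.e5m2_dtype is torch.float8_e5m2

# The aliases are resolved at module import time, so flipping the flag only
# takes effect if float8_utils is imported afterwards; the reload here is
# purely for demonstration.
config.use_fnuz_dtype = True
importlib.reload(float8_utils)
assert float8_utils.e4m3_dtype is torch.float8_e4m3fnuz
assert float8_utils.e5m2_dtype is torch.float8_e5m2fnuz

# Downstream code that uses the aliases (instead of hard-coding
# torch.float8_e4m3fn) now targets the fnuz variants, e.g. when computing a scale:
x = torch.randn(16, 16, dtype=torch.bfloat16)
scale = float8_utils.tensor_to_scale(x, float8_utils.e4m3_dtype)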