diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 588d48a..da5aec4 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -41,7 +41,7 @@ def decorator(func): @implements( [ - aten.view.default, + # aten.view.default, aten._unsafe_view.default, aten.as_strided.default, aten.clone.default, @@ -79,6 +79,19 @@ def float8_desugar_data_and_scale(aten_op, args, kwargs=None): args[0]._gemm_input_role, ) +@implements([aten.view.default]) +def float8_view(aten_op, args, kwargs=None): + if len(args[0]._scale.shape) < 2: + # tensorwise scaling + return float8_desugar_op(aten_op, *args, **kwargs) + print('args', args) + print('kwargs', kwargs) + tensor, new_shape = args[0], args[1] + + # for now, only support reshaping to [-1, *dims] or [*dims, -1] + if len(new_shape) >= 2 and (new_shape[0] == -1 or new_shape[-1] == -1): + return float8_desugar_data_and_scale(aten_op, *args, **kwargs) + raise AssertionError(f"{aten_op} with axiswise scaling and shape {new_shape} is not supported yet.") @implements([aten.split.Tensor]) def float8_split(aten_op, args, kwargs=None): diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py index 001eff4..59edd8d 100644 --- a/float8_experimental/float8_python_api.py +++ b/float8_experimental/float8_python_api.py @@ -39,13 +39,6 @@ def addmm_float8_unwrapped( a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() - # TODO: should we change torch._scaled_mm? - # torch._scaled_mm expects rowwise scaled scales to be of rank 1, not rank - # 2. Translate to this format. - # TODO: audit if we need to make this more generic for various shapes. - a_inverse_scale = a_inverse_scale.squeeze() - b_inverse_scale = b_inverse_scale.squeeze() - if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 output = torch._scaled_mm( diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 26fde8a..500d05e 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -115,9 +115,9 @@ def tensor_to_amax( # convert from axiswise_dim (dim to keep) to # dim as the input to the `torch.amax` function (tuple of dims to reduce) - dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim) + # dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim) - amax = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True) + amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will diff --git a/test/test_base.py b/test/test_base.py index 38fed52..2825fc5 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -63,7 +63,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: return True -class TestFloat8Tensor(unittest.TestCase): +class TestFloat8Tensor: def test_preserves_dtype(self) -> None: # hp means high precision, lp means low precision hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) @@ -73,7 +73,7 @@ def test_preserves_dtype(self) -> None: x1_s = tensor_to_scale(x1_hp, lp_dtype) x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype) x3_hp = x2_lp.to_original_precision() - self.assertTrue(x3_hp.dtype == hp_dtype) + assert x3_hp.dtype == hp_dtype def test_differentiable_casts(self) -> None: lp_dtypes = (e4m3_dtype, e5m2_dtype) @@ -108,7 +108,7 @@ def test_index_put(self): fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn) fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): b[index] = fp8_a fp8_b[index] = a fp8_b_bad[index] = fp8_a @@ -122,7 +122,7 @@ def test_copy_(self): b = torch.empty(16, dtype=torch.bfloat16) b.copy_(fp8_a) # Should work torch.testing.assert_close(b, fp8_a.to_original_precision()) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): fp8_a.copy_(b) # Should fail fp8_b = Float8Tensor( @@ -149,9 +149,33 @@ def test_weights_only_load(self): buffer.seek(0) _ = torch.load(buffer, weights_only=True) - def test_axiswise_dynamic_cast(self): - a = torch.randn(16, 32, dtype=torch.bfloat16) + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) + @pytest.mark.parametrize("dim_name", ["first", "last"]) + def test_axiswise_dynamic_cast(self, shape, dim_name): + a = torch.randn(*shape, dtype=torch.bfloat16) + + if dim_name == "first": + dim = 0 + elif dim_name == "last": + dim = len(a.shape) - 1 + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=dim, + ) + a_dq = a_fp8.to_original_precision() + sqnr = compute_error(a, a_dq) + assert sqnr >= 25.0 + + # TODO(next) make this work + def test_axiswise_reshape(self): + a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda") + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( a, e4m3_dtype, @@ -159,11 +183,15 @@ def test_axiswise_dynamic_cast(self): scaling_granularity=ScalingGranularity.AXISWISE, axiswise_dim=0, ) - # print(a_fp8) - # print(a_fp8.to_original_precision()) - # print(a_fp8.t()) - b = a_fp8.t() - # TODO check numerical accuracy + # a_fp8._data.shape is (3, 5, 7) + # a_fp8._scale.shape is (1, 5, 7) + print(a_fp8._scale.shape) + + # reshape to (3, 5 * 7) + # a_fp8._scale.shape should be (1, 5 * 7) + a_fp8_r = a_fp8.reshape(3, -1) + print(a_fp8_r._scale.shape) + def test_axiswise_gemm(self): a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda") @@ -177,7 +205,7 @@ def test_axiswise_gemm(self): linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=0, + axiswise_dim=1, ) b_fp8 = hp_tensor_to_float8_dynamic( b, @@ -185,10 +213,13 @@ def test_axiswise_gemm(self): linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=0, + axiswise_dim=1, ) - c = torch.mm(a_fp8, b_fp8.t()) - print(c) + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) + print(c_fp8_compute) + c_ref = torch.mm(a, b.t()) + sqnr = compute_error(c_ref, c_fp8_compute) + print('sqnr', sqnr) # TODO check numerical accuracy