diff --git a/benchmarks/bench_padding.py b/benchmarks/bench_padding.py
index 16463d5..3cb77c1 100644
--- a/benchmarks/bench_padding.py
+++ b/benchmarks/bench_padding.py
@@ -6,9 +6,9 @@
 import torch
 from float8_experimental.float8_tensor import (
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
-    ToFloat8ConstrFunc,
 )
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
@@ -58,14 +58,14 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     a_config = LinearMMConfig(a_config, a_config, a_config)
     b_config = LinearMMConfig(b_config, b_config, b_config)
 
-    a_fp8 = ToFloat8ConstrFunc.apply(
+    a_fp8 = hp_tensor_and_scale_to_float8(
         A,
         scale_a,
         fp8_dtype,
         a_config,
         GemmInputRole.INPUT,
     )
-    b_fp8 = ToFloat8ConstrFunc.apply(
+    b_fp8 = hp_tensor_and_scale_to_float8(
         B,
         scale_b,
         fp8_dtype,
diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py
index f319e75..a590d62 100644
--- a/float8_experimental/float8_scaling_utils.py
+++ b/float8_experimental/float8_scaling_utils.py
@@ -15,10 +15,10 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
-    ToFloat8ConstrFunc,
 )
 
 from float8_experimental.float8_utils import (
@@ -39,7 +39,7 @@ def cast_to_float8_e4m3_dynamic(
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return ToFloat8ConstrFunc.apply(
+    return hp_tensor_and_scale_to_float8(
         inpt_tensor,
         scale,
         e4m3_dtype,
@@ -58,7 +58,7 @@ def cast_to_float8_delayed(
     gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
 ):
     amax_buffer.fill_(tensor_to_amax(tensor))
-    return ToFloat8ConstrFunc.apply(
+    return hp_tensor_and_scale_to_float8(
         tensor,
         scale,
         float8_dtype,
@@ -145,7 +145,7 @@ def backward(ctx, go):
 
         fp8_amax_grad_output.fill_(tensor_to_amax(go))
 
-        res = ToFloat8ConstrFunc.apply(
+        res = hp_tensor_and_scale_to_float8(
             go,
             fp8_scale_grad_output,
             e5m2_dtype,
@@ -177,7 +177,7 @@ def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
-        fp8_tensor = ToFloat8ConstrFunc.apply(
+        fp8_tensor = hp_tensor_and_scale_to_float8(
             gradY,
             gradY_scale,
             e5m2_dtype,
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 62ce38d..641f972 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -129,7 +129,7 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
 
 
 @torch._dynamo.allow_in_graph
-class ToFloat8ConstrFunc(torch.autograd.Function):
+class _ToFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion to fp8.
    * forward: convert from high precision to float8
@@ -154,15 +154,6 @@ def forward(
         with that composing with FakeTensor, so we special case here.
 
         DTensor Invariant: DTensor must always be the outer most tensor subclass
-
-        Args:
-            tensor: the tensor to convert
-            scale: the scale to use to convert the tensor
-            float8_dtype: the float8 dtype to use
-            linear_mm_config: Defines the configuration for the scaled_mm for
-                the 3 fwd/bwd gemms of linear
-            gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
-                the 3 fwd/bwd gemms of linear
         """
         tensor_scaled = tensor * scale
         bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
@@ -205,7 +196,7 @@ def backward(ctx, g):
 
 
 @torch._dynamo.allow_in_graph
-class FromFloat8ConstrFunc(torch.autograd.Function):
+class _FromFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion from fp8.
    * forward: convert from float8 to high precision
@@ -221,6 +212,34 @@ def backward(ctx, g):
         return g, None, None
 
 
+def hp_tensor_and_scale_to_float8(
+    hp_tensor: torch.Tensor,
+    s: torch.Tensor,
+    float8_dtype=e4m3_dtype,
+    linear_mm_config: Optional[LinearMMConfig] = None,
+    gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+):
+    """
+    Given a high precision tensor `hp_tensor` and a precalculated scale `s`,
+    scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result.
+
+    Autograd-aware, the derivative is pass-through.
+    DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor).
+
+    Args:
+        hp_tensor: the tensor to convert
+        s: the scale to use to convert the tensor
+        float8_dtype: the float8 dtype to use
+        linear_mm_config: Defines the configuration for the scaled_mm for
+            the 3 fwd/bwd gemms of linear
+        gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+            the 3 fwd/bwd gemms of linear
+    """
+    return _ToFloat8ConstrFunc.apply(
+        hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role
+    )
+
+
 class Float8Tensor(torch.Tensor):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -309,7 +328,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
         )
 
     def to_original_precision(self):
-        return FromFloat8ConstrFunc.apply(self)
+        return _FromFloat8ConstrFunc.apply(self)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index 1cb4788..702df26 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -18,8 +18,8 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
-    ToFloat8ConstrFunc,
 )
 
 from float8_experimental.float8_utils import e4m3_dtype, EPS
@@ -167,7 +167,7 @@ def __repr__(self):
 
     def fsdp_pre_all_gather(self, mesh):
         if self._precomputed_scale is not None:
-            float8_tensor = ToFloat8ConstrFunc.apply(
+            float8_tensor = hp_tensor_and_scale_to_float8(
                 self._tensor,
                 self._precomputed_scale,
                 torch.float8_e4m3fn,
diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py
index 717695f..21c8794 100644
--- a/float8_experimental/inference.py
+++ b/float8_experimental/inference.py
@@ -19,10 +19,10 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
-    ToFloat8ConstrFunc,
 )
 
 from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
@@ -127,7 +127,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
             self.weight, Float8Tensor
         ), "Weight has already been quantized, cannot quantize again."
         scale = tensor_to_scale(self.weight, dtype)
-        quantized_weight = ToFloat8ConstrFunc.apply(
+        quantized_weight = hp_tensor_and_scale_to_float8(
             self.weight,
             scale,
             dtype,
@@ -200,7 +200,7 @@ def cast_to_float8_e4m3_inference(
         if static_quantization_scale is not None
         else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
     )
-    return ToFloat8ConstrFunc.apply(
+    return hp_tensor_and_scale_to_float8(
         inpt_tensor,
         scale,
         e4m3_dtype,
diff --git a/test/test_base.py b/test/test_base.py
index 94baec3..82d0c60 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -28,9 +28,9 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
-    ToFloat8ConstrFunc,
 )
 from float8_experimental.float8_utils import (
     compute_error,
@@ -66,7 +66,7 @@ def test_preserves_dtype(self) -> None:
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = ToFloat8ConstrFunc.apply(x1_hp, x1_s, lp_dtype)
+            x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             self.assertTrue(x3_hp.dtype == hp_dtype)
@@ -76,7 +76,7 @@ def test_differentiable_casts(self) -> None:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)
             x_s = tensor_to_scale(x, f8_dtype)
-            x_f8 = ToFloat8ConstrFunc.apply(x, x_s, f8_dtype)
+            x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
             x_f8_hp = x_f8.to_original_precision()
             x_f8_hp.backward(grad)
             # the gradient should be unchanged through both casts
@@ -85,7 +85,7 @@ def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
         scale = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale, e4m3_dtype)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)
 
         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)
@@ -94,14 +94,14 @@ def test_index_put(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
 
         index = torch.randint(0, 15, (16,), dtype=torch.long)
 
         b = torch.rand(16, 16, dtype=torch.bfloat16)
         scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
-        fp8_b = ToFloat8ConstrFunc.apply(b, scale_a, torch.float8_e4m3fn)
-        fp8_b_bad = ToFloat8ConstrFunc.apply(b, scale_b, torch.float8_e4m3fn)
+        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
+        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
 
         with self.assertRaises(AssertionError):
             b[index] = fp8_a
@@ -112,7 +112,7 @@ def test_copy_(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
 
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
@@ -407,8 +407,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = ToFloat8ConstrFunc.apply(a, a_scale, input_dtype)
-        b_fp8 = ToFloat8ConstrFunc.apply(b, b_scale, input_dtype)
+        a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
+        b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)
 
         out_scaled_mm = addmm_float8_unwrapped(
             a_fp8._data,
@@ -447,14 +447,14 @@ def test_different_configs_error(self):
             ScaledMMConfig(True, False, False, False),
             ScaledMMConfig(True, False, False, False),
         )
-        a = ToFloat8ConstrFunc.apply(
+        a = hp_tensor_and_scale_to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
             linear_config_a,
             GemmInputRole.INPUT,
         )
-        b = ToFloat8ConstrFunc.apply(
+        b = hp_tensor_and_scale_to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
@@ -486,10 +486,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a, a_scale, input_dtype, None, GemmInputRole.INPUT
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
         )
 
@@ -506,14 +506,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             scaled_mm_config, scaled_mm_config, scaled_mm_config
         )
 
-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             pad_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -529,14 +529,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             emulated_scaled_mm_config,
             emulated_scaled_mm_config,
         )
-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             emulated_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -695,19 +695,19 @@ def test_fp8_tensor_statistics(self):
 
             # Overflow caused by a too large scaling factor
             s_overflow = torch.tensor(1e9)
-            fp8_overflow = ToFloat8ConstrFunc.apply(x1_hp, s_overflow, lp_dtype)
+            fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))
 
             # Underflow caused by a too small scaling factor
             s_underflow = torch.tensor(1e-9)
-            fp8_underflow = ToFloat8ConstrFunc.apply(x1_hp, s_underflow, lp_dtype)
+            fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))
 
             # Both overflow and underflow
             x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
-            fp8_over_underflow = ToFloat8ConstrFunc.apply(
+            fp8_over_underflow = hp_tensor_and_scale_to_float8(
                 x2_hp, torch.tensor(1.0), lp_dtype
             )
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 4d56a0d..34e12d2 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -20,8 +20,8 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
-    ToFloat8ConstrFunc,
 )
 from float8_experimental.float8_tensor_parallel import (
     Float8ColwiseParallel,
@@ -87,10 +87,10 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):
     x_scale = tensor_to_scale(x_fp32, fp8_dtype).float()
     y_scale = tensor_to_scale(y_fp32, fp8_dtype).float()
 
-    x_fp8 = ToFloat8ConstrFunc.apply(
+    x_fp8 = hp_tensor_and_scale_to_float8(
         x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT
     )
-    y_fp8 = ToFloat8ConstrFunc.apply(
+    y_fp8 = hp_tensor_and_scale_to_float8(
         y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT
     )
 
@@ -117,7 +117,7 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):
 
     x_scale = tensor_to_scale(x_fp32, fp8_dtype).float()
 
-    x_fp8 = ToFloat8ConstrFunc.apply(x_fp32, x_scale, fp8_dtype)
+    x_fp8 = hp_tensor_and_scale_to_float8(x_fp32, x_scale, fp8_dtype)
 
     dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [Shard(0)], run_check=False)
     out_dist = dist_x_fp8.redistribute(placements=[Replicate()])
@@ -145,7 +145,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
     dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float()
     assert isinstance(dist_x_scale, DTensor)
 
-    dist_x_fp8 = ToFloat8ConstrFunc.apply(dist_x_fp32, dist_x_scale, fp8_dtype)
+    dist_x_fp8 = hp_tensor_and_scale_to_float8(dist_x_fp32, dist_x_scale, fp8_dtype)
     assert isinstance(dist_x_fp8, DTensor)
@@ -164,14 +164,14 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float()
     dist_target = distribute_tensor(target, mesh, [Shard(0)])
 
-    dist_x_fp8 = ToFloat8ConstrFunc.apply(
+    dist_x_fp8 = hp_tensor_and_scale_to_float8(
         dist_x_fp32,
         dist_x_scale,
         fp8_dtype,
         None,
         GemmInputRole.INPUT,
     )
-    dist_weight_fp8 = ToFloat8ConstrFunc.apply(
+    dist_weight_fp8 = hp_tensor_and_scale_to_float8(
         dist_wight_fp32,
         dist_weight_scale,
         fp8_dtype,
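
For reference, a minimal usage sketch of the new entry point. This is not part of the diff; it simply mirrors the updated calls in test/test_base.py and assumes `float8_experimental` is importable with `tensor_to_scale` and `e4m3_dtype` exposed from `float8_experimental.float8_utils`, as in the files above.

```python
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    hp_tensor_and_scale_to_float8,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

# high-precision input and a precalculated scale, as in the updated tests
x = torch.randn(16, 16, dtype=torch.bfloat16)
s = tensor_to_scale(x, e4m3_dtype)

# scales x by s and returns a Float8Tensor; the derivative is pass-through
x_fp8 = hp_tensor_and_scale_to_float8(x, s, e4m3_dtype, None, GemmInputRole.INPUT)

# round-trip back to the original precision (bfloat16 here)
x_roundtrip = x_fp8.to_original_precision()
```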