From 9a4b4ef5974f2bf01f7d8037cfe3609b0e6d0a1b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 2 Jul 2025 13:57:53 -0700 Subject: [PATCH] Rename torchao.float8.Float8Tensor to torchao.float8.Float8TrainingTensor Summary: att, since we are introducing a inference version Float8Tensor Test Plan: regression tests for float8 training: pytest test/float8 Reviewers: Subscribers: Tasks: Tags: stack-info: PR: https://github.com/pytorch/ao/pull/2479, branch: jerryzh168/stack/11 --- benchmarks/float8/bench_linear_float8.py | 2 +- benchmarks/float8/bench_padding.py | 2 +- test/float8/test_base.py | 10 +- test/float8/test_compile.py | 18 ++-- test/float8/test_dtensor.py | 10 +- test/float8/test_fsdp2/test_fsdp2.py | 2 +- .../moe_training/test_scaled_grouped_mm.py | 2 +- test/quantization/test_qat.py | 4 +- torchao/dtypes/nf4tensor.py | 2 +- torchao/float8/__init__.py | 10 +- torchao/float8/distributed_utils.py | 4 +- torchao/float8/float8_linear.py | 2 +- torchao/float8/float8_ops.py | 101 ++++++++++-------- torchao/float8/float8_scaling_utils.py | 8 +- torchao/float8/float8_tensor_parallel.py | 14 +-- ...t8_tensor.py => float8_training_tensor.py} | 24 ++--- torchao/float8/fsdp_utils.py | 18 ++-- torchao/float8/inference.py | 2 +- .../float8nocompile/float8nocompile_linear.py | 6 +- .../float8nocompile_scaling_utils.py | 2 +- .../kernels/fp8_dynamic_tensorwise.py | 56 +++++----- .../kernels/fp8_dynamic_tensorwise_test.py | 2 +- 22 files changed, 160 insertions(+), 141 deletions(-) rename torchao/float8/{float8_tensor.py => float8_training_tensor.py} (94%) diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index a7b1e17934..6d55bcc173 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -23,7 +23,7 @@ ScalingType, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_tensor import ScaledMMConfig +from torchao.float8.float8_training_tensor import ScaledMMConfig # estimating TOPs for matmuls in fp32, fp16, fp8 # assuming A * B = C, with A being M * K, B being K * N, C being M * N diff --git a/benchmarks/float8/bench_padding.py b/benchmarks/float8/bench_padding.py index eed8a5b542..62a161637b 100644 --- a/benchmarks/float8/bench_padding.py +++ b/benchmarks/float8/bench_padding.py @@ -12,7 +12,7 @@ from torch._inductor.utils import do_bench_using_profiling from tqdm import tqdm -from torchao.float8.float8_tensor import ( +from torchao.float8.float8_training_tensor import ( GemmInputRole, LinearMMConfig, ScaledMMConfig, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index df86c6f04e..d54c08a3a3 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -40,8 +40,8 @@ get_maybe_axiswise_dim, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, ScaledMMConfig, @@ -60,13 +60,13 @@ torch.manual_seed(0) -def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: +def bitwise_identical(a: Float8TrainingTensor, b: Float8TrainingTensor) -> bool: assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" return True -class TestFloat8Tensor: +class TestFloat8TrainingTensor: def test_preserves_dtype(self) -> None: # hp means high precision, lp means low precision hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) @@ 
-128,7 +128,7 @@ def test_copy_(self): with pytest.raises(RuntimeError): fp8_a.copy_(b) # Should fail - fp8_b = Float8Tensor( + fp8_b = Float8TrainingTensor( torch.empty(16, dtype=e4m3_dtype), scale_a, torch.bfloat16, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index aaf9d3d3f5..a196d87430 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -36,7 +36,11 @@ from torchao.float8.float8_scaling_utils import ( hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig +from torchao.float8.float8_training_tensor import ( + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, +) from torchao.testing.training.test_utils import get_test_float8_linear_config @@ -238,7 +242,7 @@ def forward(self, x): "CUDA with capability 9.0 or greater not available", ) def test_float8_with_graph_break_in_the_middle(self): - """Test that having Float8Tensor object at the boundary of a subgraph""" + """Test that having Float8TrainingTensor object at the boundary of a subgraph""" cnts = CompileCounterWithBackend("inductor") mod = self.MockLinear(graph_break=True).cuda() compiled_mod = copy.deepcopy(mod) @@ -254,7 +258,7 @@ def test_float8_with_graph_break_in_the_middle(self): "CUDA with float8 support not available", ) def test_float8_graph_input(self): - """Test that having Float8Tensor object as a graph input""" + """Test that having Float8TrainingTensor object as a graph input""" def to_float(x): return x.to_original_precision() @@ -278,7 +282,7 @@ def to_float(x): "CUDA with float8 support not available", ) def test_float8_graph_output(self): - """Test that having Float8Tensor object as a graph output works""" + """Test that having Float8TrainingTensor object as a graph output works""" cnts = CompileCounterWithBackend("inductor") mod = self.MockLinear(graph_break=False).cuda() compiled_mod = torch.compile(mod, backend=cnts) @@ -290,14 +294,14 @@ def test_float8_graph_output(self): for tensor in tensors: assert not isinstance( getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor - ), "Float8Tensor should not contain any FakeTensors!" + ), "Float8TrainingTensor should not contain any FakeTensors!" 
assert isinstance(y_compiled._orig_dtype, torch.dtype), ( - "Float8Tensor._orig_dtype should be a dtype but got {}".format( + "Float8TrainingTensor._orig_dtype should be a dtype but got {}".format( type(y_compiled._orig_dtype) ) ) assert isinstance(y_compiled._linear_mm_config.output.emulate, bool), ( - "Float8Tensor._emulate should be a bool but got {}".format( + "Float8TrainingTensor._emulate should be a bool but got {}".format( type(y_compiled._linear_mm_config.output.emulate) ) ) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 5509eb1cc2..56f13248d3 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -37,8 +37,8 @@ ) from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, hp_tensor_and_scale_to_float8, @@ -94,8 +94,8 @@ def _test_scaled_mm(mesh: DeviceMesh, size=16): dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False) - assert isinstance(dist_x_fp8.to_local(), Float8Tensor) - assert isinstance(dist_y_fp8.to_local(), Float8Tensor) + assert isinstance(dist_x_fp8.to_local(), Float8TrainingTensor) + assert isinstance(dist_y_fp8.to_local(), Float8TrainingTensor) assert dist_x_fp8.to_local()._orig_dtype == torch.float32 out_fp8 = torch.mm(dist_x_fp8, dist_y_fp8) local_fp8_out = out_fp8.to_local() @@ -128,7 +128,7 @@ def _test_fp8_redistribute(mesh: DeviceMesh, size=16): if isinstance(out_local, AsyncCollectiveTensor): out_local = out_local.wait() - assert isinstance(out_local, Float8Tensor) + assert isinstance(out_local, Float8TrainingTensor) assert out_local._data.dtype == fp8_dtype diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index b4c7f9fd15..ef87e5fcda 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -41,7 +41,7 @@ from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.float8_tensor import GemmInputRole +from torchao.float8.float8_training_tensor import GemmInputRole from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.training.fsdp2_utils import ( check_parity_bf16_mp, diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py index 844220c49c..3b4d23965b 100644 --- a/test/prototype/moe_training/test_scaled_grouped_mm.py +++ b/test/prototype/moe_training/test_scaled_grouped_mm.py @@ -24,7 +24,7 @@ Float8LinearRecipeName, ) from torchao.float8.float8_linear import matmul_with_hp_or_float8_args -from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_training_tensor import LinearMMConfig from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated from torchao.prototype.moe_training.scaled_grouped_mm import ( _scaled_grouped_mm, diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index f0404a2ac2..de79ea4122 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -19,7 +19,7 @@ from torchao import quantize_ 
from torchao.float8.config import ScalingGranularity from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_training_tensor import LinearMMConfig from torchao.quantization.granularity import ( PerAxis, PerGroup, @@ -1696,7 +1696,7 @@ def test_qat_range_learning(self): def test_float8_rowwise_fake_quantize(self): """ - Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8Tensor`. + Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`. """ torch.manual_seed(self.SEED) dtype = torch.float8_e4m3fn diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index e6662b350a..4764e8b69b 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -943,7 +943,7 @@ def allowed_subclasses(type): f"NF4Tensor dispatch: attempting to run {func}, this is not supported" ) - # Do not force the Float8Tensor type on the returned tensor + # Do not force the Float8TrainingTensor type on the returned tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 4f90292918..170d0ddd81 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -10,8 +10,8 @@ _auto_filter_for_recipe, convert_to_float8_training, ) -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, ScaledMMConfig, @@ -22,12 +22,12 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if TORCH_VERSION_AT_LEAST_2_5: - # Needed to load Float8Tensor with weights_only = True + # Needed to load Float8TrainingTensor with weights_only = True from torch.serialization import add_safe_globals add_safe_globals( [ - Float8Tensor, + Float8TrainingTensor, ScaledMMConfig, GemmInputRole, LinearMMConfig, @@ -50,5 +50,5 @@ "_auto_filter_for_recipe", # types "FP8Granularity", - # note: Float8Tensor and Float8Linear are not public APIs + # note: Float8TrainingTensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/distributed_utils.py b/torchao/float8/distributed_utils.py index cd1560fabd..a278640af8 100644 --- a/torchao/float8/distributed_utils.py +++ b/torchao/float8/distributed_utils.py @@ -8,7 +8,7 @@ import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import DTensor -from torchao.float8.float8_tensor import Float8Tensor +from torchao.float8.float8_training_tensor import Float8TrainingTensor def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: @@ -16,7 +16,7 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: Check if the tensor is already casted to fp8, works if the local tensor is wrapped in DTensor. """ - if isinstance(tensor, Float8Tensor): + if isinstance(tensor, Float8TrainingTensor): return True elif isinstance(tensor, DTensor): # TODO: shall we stick to public API and directly use tensor.to_local() here? 
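Not part of the diff above: a minimal sketch of how the renamed `tensor_already_casted_to_fp8` check from `torchao/float8/distributed_utils.py` behaves after this change. It constructs a `Float8TrainingTensor` directly using the positional argument order visible in the hunks above (data, scale, orig_dtype, linear_mm_config, gemm_input_role); the assumption that `LinearMMConfig()` is constructible with its defaults is not confirmed by this patch.

```python
# Editor's sketch, not part of the patch. Assumes a torchao build that includes
# this rename, a PyTorch with float8 dtypes, and that LinearMMConfig() can be
# constructed with its default values.
import torch

from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_training_tensor import (
    Float8TrainingTensor,
    GemmInputRole,
    LinearMMConfig,
)

data = torch.randn(16, 16).to(torch.float8_e4m3fn)  # raw fp8 payload
scale = torch.tensor(1.0)                            # tensorwise scale

# Wrap the payload in the renamed subclass:
# (data, scale, orig_dtype, linear_mm_config, gemm_input_role)
fp8 = Float8TrainingTensor(
    data, scale, torch.bfloat16, LinearMMConfig(), GemmInputRole.INPUT
)

assert tensor_already_casted_to_fp8(fp8)                  # Float8TrainingTensor -> True
assert not tensor_already_casted_to_fp8(torch.randn(4))   # plain tensor -> False
```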
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index fbafc1a393..95102a873d 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -17,7 +17,7 @@ get_maybe_axiswise_dim, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( +from torchao.float8.float8_training_tensor import ( GemmInputRole, LinearMMConfig, ScaledMMConfig, diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 7e5432c6c5..58b018d0c0 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -8,7 +8,10 @@ import torch from torch.utils._pytree import tree_map -from torchao.float8.float8_tensor import Float8Tensor, choose_scaled_mm_config +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, + choose_scaled_mm_config, +) from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul aten = torch.ops.aten @@ -18,7 +21,7 @@ # [Note] Usage of scales -# The meaning of scale in this library can be found in the definition of the Float8Tensor +# The meaning of scale in this library can be found in the definition of the Float8TrainingTensor # Cublas defines scale to always mean a multiplicative factor for the respective matrices # For a,b going from fp8 -> fp32 we multiple by the inverse of the scale # For output going from fp32 -> fp8 we multiply by the scale @@ -33,7 +36,7 @@ def addmm_float8_unwrapped( use_fast_accum: bool = False, ) -> torch.Tensor: """ - This is the unwrapped version of addmm_float8, which does not take in Float8Tensors + This is the unwrapped version of addmm_float8, which does not take in Float8TrainingTensors as inputs. This is used to standardize the logic between subclassed and non subclassed versions of the linear module. 
""" @@ -124,7 +127,7 @@ def decorator(func): def float8_desugar_op(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) new_data = aten_op(args[0]._data, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( new_data, args[0]._scale, args[0]._orig_dtype, @@ -141,7 +144,7 @@ def float8_desugar_op(aten_op, args, kwargs=None): def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( new_data, new_scale, args[0]._orig_dtype, @@ -174,7 +177,7 @@ def float8_transpose(aten_op, args, kwargs=None): else: new_axiswise_dim == 0 - return Float8Tensor( + return Float8TrainingTensor( new_data, new_scale, args[0]._orig_dtype, @@ -192,7 +195,7 @@ def float8_view(aten_op, args, kwargs=None): # note that we have to create a new wrapper to make PyTorch internals happy if new_shape == list(t._data.shape): new_data = aten_op(args[0]._data, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( new_data, args[0]._scale, args[0]._orig_dtype, @@ -212,7 +215,7 @@ def float8_view(aten_op, args, kwargs=None): new_data = aten_op(t._data, new_shape, **kwargs) new_scale_shape = [1, new_shape[-1]] new_scale = aten_op(t._scale, new_scale_shape, **kwargs) - return Float8Tensor( + return Float8TrainingTensor( new_data, new_scale, t._orig_dtype, @@ -225,7 +228,7 @@ def float8_view(aten_op, args, kwargs=None): new_scale_shape = [new_shape[0], 1] new_scale = aten_op(t._scale, new_scale_shape, **kwargs) new_axiswise_dim = -1 - return Float8Tensor( + return Float8TrainingTensor( new_data, new_scale, t._orig_dtype, @@ -245,7 +248,7 @@ def float8_split(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) def make_float8(data): - return Float8Tensor( + return Float8TrainingTensor( data, args[0]._scale, args[0]._orig_dtype, @@ -260,7 +263,7 @@ def make_float8(data): # Errors cant `cat_cuda float8 e4m3fn` @implements([aten.cat.default]) def float8_cat(aten_op, args, kwargs=None): - chunked_tensors: Tuple[Float8Tensor] = args[0] + chunked_tensors: Tuple[Float8TrainingTensor] = args[0] orig_dtype = chunked_tensors[0]._orig_dtype scale = chunked_tensors[0]._scale @@ -269,8 +272,8 @@ def float8_cat(aten_op, args, kwargs=None): gemm_input_role = chunked_tensors[0]._gemm_input_role chunk_data = [] for chunk in chunked_tensors: - assert isinstance(chunk, Float8Tensor), ( - "Expecting all chunks to be of type Float8Tensor" + assert isinstance(chunk, Float8TrainingTensor), ( + "Expecting all chunks to be of type Float8TrainingTensor" ) assert chunk._orig_dtype == orig_dtype, ( "Expecting all chunks to be of the same dtype" @@ -292,7 +295,7 @@ def float8_cat(aten_op, args, kwargs=None): new_data = aten_op(chunk_data, *args[1:], **kwargs) new_data = new_data.view(fp8_dtype) - return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role) + return Float8TrainingTensor(new_data, scale, orig_dtype, mm_config, gemm_input_role) @implements([aten.sum.dim_IntList]) @@ -307,7 +310,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) def unwrap(x): - if isinstance(x, Float8Tensor): + if isinstance(x, Float8TrainingTensor): return x.to_original_precision() return x @@ -316,7 +319,7 @@ def unwrap(x): return aten_op(*new_args, **new_kwargs) -def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): +def preprocess_addmm(a: 
Float8TrainingTensor, b: Float8TrainingTensor): a_data = a._data a_scale = a._scale b_data = b._data @@ -362,10 +365,10 @@ def float8_mm(aten_op, args, kwargs=None): a = args[0] b = args[1] - assert isinstance(a, Float8Tensor) and isinstance(b, Float8Tensor), ( - "Expecting both Float8Tensor for mm inputs but found {} and {}".format( - type(a), type(b) - ) + assert isinstance(a, Float8TrainingTensor) and isinstance( + b, Float8TrainingTensor + ), "Expecting both Float8TrainingTensor for mm inputs but found {} and {}".format( + type(a), type(b) ) a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype @@ -396,8 +399,8 @@ def float8_mm(aten_op, args, kwargs=None): def float8_addmm(aten_op, args, kwargs=None): assert ( isinstance(args[0], torch.Tensor) - and isinstance(args[1], Float8Tensor) - and isinstance(args[2], Float8Tensor) + and isinstance(args[1], Float8TrainingTensor) + and isinstance(args[2], Float8TrainingTensor) ) bias = args[0] a = args[1] @@ -438,19 +441,19 @@ def float8_is_same_size(aten_op, args, kwargs=None): @implements([aten._to_copy.default]) def autocast_to_copy(aten_op, args, kwargs=None): """This gets called when running matmul under autocast - when the input is a Float8Tensor, presenting as a fp32 + when the input is a Float8TrainingTensor, presenting as a fp32 tensor. """ _assert_tensorwise_scale(aten_op, args[0]._scale) - assert isinstance(args[0], Float8Tensor) + assert isinstance(args[0], Float8TrainingTensor) assert len(kwargs) == 1 and "dtype" in kwargs, ( "Only support dtype kwarg for autocast" ) assert kwargs["dtype"] in { torch.float16, torch.bfloat16, - }, "Only support floating point conversion for autocast w/ Float8Tensor" - return Float8Tensor( + }, "Only support floating point conversion for autocast w/ Float8TrainingTensor" + return Float8TrainingTensor( args[0]._data, args[0]._scale, kwargs["dtype"], @@ -471,14 +474,14 @@ def allgather_fp8(aten_op, args, kwargs=None): """ _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] - assert isinstance(fp8_input, Float8Tensor), ( - f"expecting a Float8Tensor for allgather but found {type(fp8_input)}" + assert isinstance(fp8_input, Float8TrainingTensor), ( + f"expecting a Float8TrainingTensor for allgather but found {type(fp8_input)}" ) fp8_data = fp8_input._data fp8_data = fp8_data.contiguous() fp8_out = aten_op(fp8_data, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( fp8_out, fp8_input._scale, fp8_input._orig_dtype, @@ -491,11 +494,11 @@ def allgather_fp8(aten_op, args, kwargs=None): def wait_tensor_fp8(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] - assert isinstance(fp8_input, Float8Tensor) + assert isinstance(fp8_input, Float8TrainingTensor) fp8_data = fp8_input._data fp8_out = aten_op(fp8_data, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( fp8_out, fp8_input._scale, fp8_input._orig_dtype, @@ -508,8 +511,8 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): def index_put_fp8(aten_op, args, kwargs=None): fp8_self = args[0] fp8_values = args[2] - assert isinstance(fp8_self, Float8Tensor) - assert isinstance(fp8_values, Float8Tensor) + assert isinstance(fp8_self, Float8TrainingTensor) + assert isinstance(fp8_values, Float8TrainingTensor) _assert_tensorwise_scale(fp8_self, args[0]._scale) assert fp8_self._scale == fp8_values._scale assert fp8_self.dtype == fp8_values.dtype @@ -518,7 +521,7 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_data = 
fp8_self._data fp8_values_data = fp8_values._data fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( fp8_out, fp8_self._scale, fp8_self._orig_dtype, @@ -529,39 +532,43 @@ def index_put_fp8(aten_op, args, kwargs=None): @implements([aten.copy_.default]) def copy_fp8(aten_op, args, kwargs=None): - # For a copy op with Float8Tensors involved, only the following combinations are allowed: - # 1. self is a high precision (hp) tensor, src is a Float8Tensor: + # For a copy op with Float8TrainingTensors involved, only the following combinations are allowed: + # 1. self is a high precision (hp) tensor, src is a Float8TrainingTensor: # in this case src is upcasted and unscaled to go into the hp tensor - # 2. self and src are Float8Tensors: - # the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat) + # 2. self and src are Float8TrainingTensors: + # the copy is only allowed if all the Float8TrainingTensor properties are equal (a la torch.cat) # Every other combination is banned as the semantics are not well defined self = args[0] src = args[1] - if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + if not isinstance(self, Float8TrainingTensor) and isinstance( + src, Float8TrainingTensor + ): src_hp = src.to_original_precision() _assert_tensorwise_scale(aten_op, src._scale) return aten_op(self, src_hp, *args[2:], **kwargs) - elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + elif isinstance(self, Float8TrainingTensor) and isinstance( + src, Float8TrainingTensor + ): _assert_tensorwise_scale(aten_op, src._scale) assert self._orig_dtype == src._orig_dtype, ( - "Expecting both Float8Tensors to be of the same dtype" + "Expecting both Float8TrainingTensors to be of the same dtype" ) assert self._scale == src._scale, ( - "Expecting both Float8Tensors to have thee same scale" + "Expecting both Float8TrainingTensors to have thee same scale" ) assert self._linear_mm_config == src._linear_mm_config, ( - "Expecting both Float8Tensors to have thee same mm config" + "Expecting both Float8TrainingTensors to have thee same mm config" ) assert self._data.dtype == src._data.dtype, ( - "Expecting both Float8Tensors to be of the same dtypet" + "Expecting both Float8TrainingTensors to be of the same dtypet" ) assert self._gemm_input_role == src._gemm_input_role, ( - "Expecting both Float8Tensors to have the same gemm_input_role" + "Expecting both Float8TrainingTensors to have the same gemm_input_role" ) fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( fp8_out, self._scale, self._orig_dtype, @@ -569,4 +576,4 @@ def copy_fp8(aten_op, args, kwargs=None): self._gemm_input_role, ) else: - raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") + raise RuntimeError("Unsupported semantics for copy_ in Float8TrainingTensor") diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 31f2db6b4e..5a9138a1e9 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -14,8 +14,8 @@ from torchao.float8.config import ScalingGranularity from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, hp_tensor_and_scale_to_float8, @@ -36,10 +36,10 @@ def 
hp_tensor_to_float8_dynamic( scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, round_scales_to_power_of_2: bool = False, -) -> Float8Tensor: +) -> Float8TrainingTensor: """ Given a high precision tensor `hp_tensor`, - scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result. + scales `hp_tensor` dynamically and returns a `Float8TrainingTensor` of the result. Args: hp_tensor: the tensor to convert diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index 36ae6d587e..175712c231 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -19,7 +19,7 @@ NoopFwToFloat8BwDynamic, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import GemmInputRole +from torchao.float8.float8_training_tensor import GemmInputRole # subclass the ColwiseParallel and RowwiseParallel classes # to add the float8 support @@ -62,7 +62,7 @@ def _prepare_input_fn( mod.config.cast_config_input.target_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - ) # DTensor(Float8Tensor) + ) # DTensor(Float8TrainingTensor) # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: @@ -79,7 +79,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me placements=output_layouts, async_op=True ) # DTensor(torch.Tensor) - # fwd noop bwd cast to DTensor(Float8Tensor) + # fwd noop bwd cast to DTensor(Float8TrainingTensor) outputs = NoopFwToFloat8BwDynamic.apply( outputs, mod.linear_mm_config, @@ -126,7 +126,7 @@ def _prepare_input_fn( mod.config.cast_config_input.target_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - ) # DTensor(Float8Tensor) + ) # DTensor(Float8TrainingTensor) if input_layouts != desired_input_layouts: input_tensor = input_tensor.redistribute( @@ -142,7 +142,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me if outputs.placements != output_layouts: outputs = outputs.redistribute(placements=output_layouts, async_op=True) - # fwd noop bwd cast to DTensor(Float8Tensor) + # fwd noop bwd cast to DTensor(Float8TrainingTensor) outputs = NoopFwToFloat8BwDynamic.apply( outputs, mod.linear_mm_config, @@ -173,7 +173,7 @@ class PrepareFloat8ModuleInput(PrepareModuleInput): currently assumes tensorwise scaling. The only difference from `PrepareModuleInput` is that - after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor) + after we prepare the input DTensor, we cast the input to DTensor(Float8TrainingTensor) This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) so that if there are multiple float8 users of the input activation, we perform fp8 allgather only once. 
@@ -234,7 +234,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): e4m3_dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - ) # DTensor(Float8Tensor) + ) # DTensor(Float8TrainingTensor) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_training_tensor.py similarity index 94% rename from torchao/float8/float8_tensor.py rename to torchao/float8/float8_training_tensor.py index 6b5177e1fe..96c5c9e086 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_training_tensor.py @@ -66,7 +66,7 @@ class LinearMMConfig(NamedTuple): Configuration for different gemm operations in LinearMM. This configuration is not user-facing and exists for convenience, - allowing Float8Tensor to use the right config based on which gemm + allowing Float8TrainingTensor to use the right config based on which gemm from gemms with outputs `output`, `grad_input`, `grad_weight` is being called. Attributes: @@ -82,7 +82,7 @@ class LinearMMConfig(NamedTuple): class GemmInputRole(enum.Enum): """ - Given a Float8Tensor, the enum below describes the expected role of this + Given a Float8TrainingTensor, the enum below describes the expected role of this tensor in the three gemms present in the fw + bw pass of a Linear layer. This is used to choose the right config for a float8 gemm when the gemm is performed. @@ -138,7 +138,7 @@ def forward( axiswise_dim: Optional[int] = None, ): """ - This function will apply the scaling, and then convert to a Float8Tensor + This function will apply the scaling, and then convert to a Float8TrainingTensor Note: We will call this function with a DTensor subclass. Ideally this would be an aten OP @@ -161,7 +161,7 @@ def forward( bits_placements = bits_fp8.placements local_bits = bits_fp8.to_local() local_scale = scale.to_local() - inner_float8_tensor = Float8Tensor( + inner_float8_tensor = Float8TrainingTensor( local_bits, local_scale, tensor.dtype, @@ -178,7 +178,7 @@ def forward( stride=bits_fp8.stride(), ) - return Float8Tensor( + return Float8TrainingTensor( bits_fp8, scale, tensor.dtype, @@ -219,10 +219,10 @@ def hp_tensor_and_scale_to_float8( ): """ Given a high precision tensor `hp_tensor` and a precalculated scale `s`, - scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result. + scales `hp_tensor` by `s` and returns a `Float8TrainingTensor` of the result. Autograd-aware, the derivative is pass-through. - DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor). + DTensor-aware, if the input is a DTensor the output will be DTensor(Float8TrainingTensor). Args: hp_tensor: the tensor to convert @@ -239,7 +239,7 @@ def hp_tensor_and_scale_to_float8( ) -class Float8Tensor(torch.Tensor): +class Float8TrainingTensor(torch.Tensor): """ Note: this is **not** a public API and is only intended to be used inside of this repository. 
Please file an issue if you would benefit @@ -319,7 +319,7 @@ def __new__( return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" + return f"Float8TrainingTensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { @@ -333,7 +333,7 @@ def __tensor_flatten__(self): @staticmethod def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride): assert len(inner_tensors) == 2 - return Float8Tensor( + return Float8TrainingTensor( inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"], @@ -355,7 +355,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Lazy import to avoid circular dependency from torchao.float8.float8_ops import FLOAT8_OPS_TABLE - # All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs + # All ops in the FLOAT8_OPS_TABLE expect Float8TrainingTensor as inputs # And don't support mixed tensor subclasses. This will trigger the handler for # the next type in the dispatch list def allowed_subclasses(type): @@ -374,5 +374,5 @@ def allowed_subclasses(type): return FLOAT8_OPS_TABLE[func](func, args, kwargs) raise NotImplementedError(f"attempting to run {func}, this is not supported") - # Do not force the Float8Tensor type on the returned tensor + # Do not force the Float8TrainingTensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 7b24dc2b53..9c3379278b 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -15,8 +15,8 @@ from torchao.float8.float8_scaling_utils import ( hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, hp_tensor_and_scale_to_float8, @@ -217,7 +217,7 @@ def __repr__(self): def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: - float8_tensor = hp_tensor_and_scale_to_float8( + float8_training_tensor = hp_tensor_and_scale_to_float8( self._tensor, self._precomputed_scale, self._dtype, @@ -225,7 +225,7 @@ def fsdp_pre_all_gather(self, mesh): GemmInputRole.WEIGHT, ) else: - float8_tensor = hp_tensor_to_float8_dynamic( + float8_training_tensor = hp_tensor_to_float8_dynamic( self._tensor, self._dtype, self._linear_mm_config, @@ -233,7 +233,7 @@ def fsdp_pre_all_gather(self, mesh): gemm_input_role=GemmInputRole.WEIGHT, device_mesh=mesh, ) - return (float8_tensor._data,), (float8_tensor._scale,) + return (float8_training_tensor._data,), (float8_training_tensor._scale,) def fsdp_post_all_gather( self, @@ -248,18 +248,18 @@ def fsdp_post_all_gather( if out is not None: from torch.distributed._tensor import DTensor - if isinstance(out, Float8Tensor): + if isinstance(out, Float8TrainingTensor): out._scale = scale elif isinstance(out, DTensor) and isinstance( - out._local_tensor, Float8Tensor + out._local_tensor, Float8TrainingTensor ): out._local_tensor._scale = scale else: raise RuntimeError( - f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}" + f"out must be a 
Float8TrainingTensor or DTensor(_local_tensor=Float8TrainingTensor), but got {out}" ) return - return Float8Tensor( + return Float8TrainingTensor( data, scale, param_dtype, diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 144f1fa6f2..0a766adb45 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -78,7 +78,7 @@ def addmm_float8_unwrapped_inference( use_fast_accum: bool = False, ) -> Tensor: """ - This is the unwrapped version of addmm_float8, which does not take in Float8Tensors + This is the unwrapped version of addmm_float8, which does not take in Float8TrainingTensors as inputs. This is used to standardize the logic between subclassed and non subclassed versions of the linear module. """ diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear.py b/torchao/prototype/float8nocompile/float8nocompile_linear.py index 7e0eb85022..b7ee306066 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear.py @@ -11,7 +11,11 @@ import torch from torchao.float8.config import Float8LinearConfig -from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig +from torchao.float8.float8_training_tensor import ( + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, +) from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import ( ToFP8ColumnMajor, ToFP8ColumnMajorT, diff --git a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py index 7b6a25e3f9..1e55c0c2e9 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py +++ b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py @@ -10,7 +10,7 @@ import torch -from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig +from torchao.float8.float8_training_tensor import GemmInputRole, LinearMMConfig from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( KernelAlgorithm, hp_to_fp8_col_major, diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py index 3786b52eb5..37c7611980 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py @@ -14,7 +14,11 @@ import triton import triton.language as tl -from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, + GemmInputRole, + LinearMMConfig, +) EPS = 1e-12 @@ -487,7 +491,7 @@ def _scale_atomic( tl.float32 ) - # store scale for use in Float8Tensor constructor + # store scale for use in Float8TrainingTensor constructor scale_off = tl.arange(0, 1) tl.store(scale_out_ptr + scale_off, scale) @@ -541,7 +545,7 @@ def hp_to_fp8_row_major( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() @@ -576,8 +580,8 @@ def hp_to_fp8_row_major( EPS=EPS, ) - # wrap output tensor in Float8Tensor - fp8_tensor_row_major = Float8Tensor( + # wrap output tensor in Float8TrainingTensor + fp8_tensor_row_major = Float8TrainingTensor( output_buffer, scale, orig_dtype=hp_tensor.dtype, @@ -593,7 
+597,7 @@ def hp_to_fp8_row_major_t( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() @@ -641,8 +645,8 @@ def hp_to_fp8_row_major_t( EPS=EPS, ) - # wrap output tensor in Float8Tensor - fp8_tensor_row_major_t = Float8Tensor( + # wrap output tensor in Float8TrainingTensor + fp8_tensor_row_major_t = Float8TrainingTensor( output_buffer, scale, orig_dtype=hp_tensor.dtype, @@ -658,7 +662,7 @@ def hp_to_fp8_col_major( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() @@ -705,8 +709,8 @@ def hp_to_fp8_col_major( col_major_strides = (1, num_rows) output_buffer = output_buffer.as_strided(output_buffer.size(), col_major_strides) - # wrap output tensor in Float8Tensor - fp8_tensor_col_major = Float8Tensor( + # wrap output tensor in Float8TrainingTensor + fp8_tensor_col_major = Float8TrainingTensor( output_buffer, scale, orig_dtype=hp_tensor.dtype, @@ -722,7 +726,7 @@ def hp_to_fp8_col_major_t( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() @@ -757,8 +761,8 @@ def hp_to_fp8_col_major_t( EPS=EPS, ) - # wrap output tensor in Float8Tensor - fp8_tensor_col_major_t = Float8Tensor( + # wrap output tensor in Float8TrainingTensor + fp8_tensor_col_major_t = Float8TrainingTensor( output_buffer, scale, orig_dtype=hp_tensor.dtype, @@ -774,7 +778,7 @@ def hp_to_fp8_row_and_col_major( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] @@ -830,15 +834,15 @@ def hp_to_fp8_row_and_col_major( fp8_output_col_major.size(), col_major_strides ) - # wrap outputs in Float8Tensors - fp8_tensor_row_major = Float8Tensor( + # wrap outputs in Float8TrainingTensors + fp8_tensor_row_major = Float8TrainingTensor( fp8_output_row_major, scale, orig_dtype=hp_tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, ) - fp8_tensor_col_major = Float8Tensor( + fp8_tensor_col_major = Float8TrainingTensor( fp8_output_col_major, scale, orig_dtype=hp_tensor.dtype, @@ -854,7 +858,7 @@ def hp_to_fp8_row_major_t_and_non_t( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] @@ -912,15 +916,15 @@ def hp_to_fp8_row_major_t_and_non_t( EPS=EPS, ) - # wrap outputs in Float8Tensors - fp8_tensor_row_major = Float8Tensor( + # wrap outputs in Float8TrainingTensors + fp8_tensor_row_major = Float8TrainingTensor( fp8_output_row_major, scale, orig_dtype=hp_tensor.dtype, linear_mm_config=linear_mm_config, 
gemm_input_role=gemm_input_role, ) - fp8_tensor_row_major_t = Float8Tensor( + fp8_tensor_row_major_t = Float8TrainingTensor( fp8_output_row_major_t, scale, orig_dtype=hp_tensor.dtype, @@ -936,7 +940,7 @@ def hp_to_fp8_col_major_t_and_non_t( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] @@ -999,15 +1003,15 @@ def hp_to_fp8_col_major_t_and_non_t( fp8_output_col_major.size(), col_major_strides ) - # wrap outputs in Float8Tensors - fp8_tensor_col_major = Float8Tensor( + # wrap outputs in Float8TrainingTensors + fp8_tensor_col_major = Float8TrainingTensor( fp8_output_col_major, scale, orig_dtype=hp_tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, ) - fp8_tensor_col_major_t = Float8Tensor( + fp8_tensor_col_major_t = Float8TrainingTensor( fp8_output_col_major_t, scale, orig_dtype=hp_tensor.dtype, diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py index 2348877d5c..0d7a20fae7 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py @@ -7,7 +7,7 @@ import torch from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_training_tensor import LinearMMConfig from torchao.float8.float8_utils import is_row_major from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( KernelAlgorithm,
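For code outside this repository that imported from the old module path, the change implied by this rename is a one-line import update. A minimal sketch, assuming the renamed module ships exactly as shown in this diff:

```python
# Editor's sketch of the downstream import migration implied by this rename.
# Old path (before this patch):
#   from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig
# New path (after this patch):
from torchao.float8.float8_training_tensor import (
    Float8TrainingTensor,   # renamed from Float8Tensor
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
)
```

As the `torchao/float8/__init__.py` comment in this patch notes, `Float8TrainingTensor` (like `Float8Linear`) is not a public API, so imports of it from outside the repository are already outside the supported surface.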