Rename torchao.float8.Float8Tensor to torchao.float8.Float8TrainingTensor #2479

Open · wants to merge 1 commit into base: jerryzh168/stack/4
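
Every hunk in this diff is a mechanical rename: the module torchao.float8.float8_tensor becomes torchao.float8.float8_training_tensor and the class Float8Tensor becomes Float8TrainingTensor. A minimal before/after sketch of an import site (the inspect_scale helper is illustrative only, not part of torchao or this PR):

import torch

# Before this PR:
# from torchao.float8.float8_tensor import Float8Tensor, LinearMMConfig

# After this PR:
from torchao.float8.float8_training_tensor import (
    Float8TrainingTensor,
    LinearMMConfig,
)

def inspect_scale(t: Float8TrainingTensor) -> torch.Tensor:
    # Illustrative helper: the subclass keeps its private fields
    # (_data, _scale, _orig_dtype), as exercised by the updated tests below.
    return t._scale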
benchmarks/float8/bench_linear_float8.py (1 addition, 1 deletion)
@@ -23,7 +23,7 @@
     ScalingType,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_tensor import ScaledMMConfig
+from torchao.float8.float8_training_tensor import ScaledMMConfig
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
benchmarks/float8/bench_padding.py (1 addition, 1 deletion)
@@ -12,7 +12,7 @@
 from torch._inductor.utils import do_bench_using_profiling
 from tqdm import tqdm
 
-from torchao.float8.float8_tensor import (
+from torchao.float8.float8_training_tensor import (
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
test/float8/test_base.py (5 additions, 5 deletions)
@@ -40,8 +40,8 @@
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import (
-    Float8Tensor,
+from torchao.float8.float8_training_tensor import (
+    Float8TrainingTensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
@@ -60,13 +60,13 @@
 torch.manual_seed(0)
 
 
-def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
+def bitwise_identical(a: Float8TrainingTensor, b: Float8TrainingTensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
     return True
 
 
-class TestFloat8Tensor:
+class TestFloat8TrainingTensor:
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -128,7 +128,7 @@ def test_copy_(self):
         with pytest.raises(RuntimeError):
             fp8_a.copy_(b)  # Should fail
 
-        fp8_b = Float8Tensor(
+        fp8_b = Float8TrainingTensor(
             torch.empty(16, dtype=e4m3_dtype),
             scale_a,
             torch.bfloat16,
test/float8/test_compile.py (11 additions, 7 deletions)
@@ -36,7 +36,11 @@
 from torchao.float8.float8_scaling_utils import (
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
+from torchao.float8.float8_training_tensor import (
+    GemmInputRole,
+    LinearMMConfig,
+    ScaledMMConfig,
+)
 from torchao.testing.training.test_utils import get_test_float8_linear_config
 
 
@@ -238,7 +242,7 @@ def forward(self, x):
         "CUDA with capability 9.0 or greater not available",
     )
     def test_float8_with_graph_break_in_the_middle(self):
-        """Test that having Float8Tensor object at the boundary of a subgraph"""
+        """Test that having Float8TrainingTensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=True).cuda()
         compiled_mod = copy.deepcopy(mod)
@@ -254,7 +258,7 @@ def test_float8_with_graph_break_in_the_middle(self):
         "CUDA with float8 support not available",
     )
     def test_float8_graph_input(self):
-        """Test that having Float8Tensor object as a graph input"""
+        """Test that having Float8TrainingTensor object as a graph input"""
 
         def to_float(x):
             return x.to_original_precision()
@@ -278,7 +282,7 @@ def to_float(x):
         "CUDA with float8 support not available",
     )
     def test_float8_graph_output(self):
-        """Test that having Float8Tensor object as a graph output works"""
+        """Test that having Float8TrainingTensor object as a graph output works"""
         cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=False).cuda()
         compiled_mod = torch.compile(mod, backend=cnts)
@@ -290,14 +294,14 @@ def test_float8_graph_output(self):
         for tensor in tensors:
             assert not isinstance(
                 getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
-            ), "Float8Tensor should not contain any FakeTensors!"
+            ), "Float8TrainingTensor should not contain any FakeTensors!"
         assert isinstance(y_compiled._orig_dtype, torch.dtype), (
-            "Float8Tensor._orig_dtype should be a dtype but got {}".format(
+            "Float8TrainingTensor._orig_dtype should be a dtype but got {}".format(
                 type(y_compiled._orig_dtype)
            )
        )
         assert isinstance(y_compiled._linear_mm_config.output.emulate, bool), (
-            "Float8Tensor._emulate should be a bool but got {}".format(
+            "Float8TrainingTensor._emulate should be a bool but got {}".format(
                 type(y_compiled._linear_mm_config.output.emulate)
            )
        )
test/float8/test_dtensor.py (5 additions, 5 deletions)
@@ -37,8 +37,8 @@
 )
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
-from torchao.float8.float8_tensor import (
-    Float8Tensor,
+from torchao.float8.float8_training_tensor import (
+    Float8TrainingTensor,
     GemmInputRole,
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
@@ -94,8 +94,8 @@ def _test_scaled_mm(mesh: DeviceMesh, size=16):
     dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False)
     dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False)
 
-    assert isinstance(dist_x_fp8.to_local(), Float8Tensor)
-    assert isinstance(dist_y_fp8.to_local(), Float8Tensor)
+    assert isinstance(dist_x_fp8.to_local(), Float8TrainingTensor)
+    assert isinstance(dist_y_fp8.to_local(), Float8TrainingTensor)
     assert dist_x_fp8.to_local()._orig_dtype == torch.float32
     out_fp8 = torch.mm(dist_x_fp8, dist_y_fp8)
     local_fp8_out = out_fp8.to_local()
@@ -128,7 +128,7 @@ def _test_fp8_redistribute(mesh: DeviceMesh, size=16):
     if isinstance(out_local, AsyncCollectiveTensor):
         out_local = out_local.wait()
 
-    assert isinstance(out_local, Float8Tensor)
+    assert isinstance(out_local, Float8TrainingTensor)
     assert out_local._data.dtype == fp8_dtype
 
 
test/float8/test_fsdp2/test_fsdp2.py (1 addition, 1 deletion)
@@ -41,7 +41,7 @@
 from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
-from torchao.float8.float8_tensor import GemmInputRole
+from torchao.float8.float8_training_tensor import GemmInputRole
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.training.fsdp2_utils import (
     check_parity_bf16_mp,
test/prototype/moe_training/test_scaled_grouped_mm.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
     Float8LinearRecipeName,
 )
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
-from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.float8.float8_training_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,
test/quantization/test_qat.py (2 additions, 2 deletions)
@@ -19,7 +19,7 @@
 from torchao import quantize_
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
-from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.float8.float8_training_tensor import LinearMMConfig
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
@@ -1696,7 +1696,7 @@ def test_qat_range_learning(self):
 
     def test_float8_rowwise_fake_quantize(self):
         """
-        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8Tensor`.
+        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
         """
         torch.manual_seed(self.SEED)
         dtype = torch.float8_e4m3fn
torchao/dtypes/nf4tensor.py (1 addition, 1 deletion)
@@ -943,7 +943,7 @@ def allowed_subclasses(type):
             f"NF4Tensor dispatch: attempting to run {func}, this is not supported"
         )
 
-    # Do not force the Float8Tensor type on the returned tensor
+    # Do not force the Float8TrainingTensor type on the returned tensor
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
torchao/float8/__init__.py (5 additions, 5 deletions)
@@ -10,8 +10,8 @@
     _auto_filter_for_recipe,
     convert_to_float8_training,
 )
-from torchao.float8.float8_tensor import (
-    Float8Tensor,
+from torchao.float8.float8_training_tensor import (
+    Float8TrainingTensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
@@ -22,12 +22,12 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 if TORCH_VERSION_AT_LEAST_2_5:
-    # Needed to load Float8Tensor with weights_only = True
+    # Needed to load Float8TrainingTensor with weights_only = True
     from torch.serialization import add_safe_globals
 
     add_safe_globals(
         [
-            Float8Tensor,
+            Float8TrainingTensor,
             ScaledMMConfig,
             GemmInputRole,
             LinearMMConfig,
@@ -50,5 +50,5 @@
     "_auto_filter_for_recipe",
     # types
     "FP8Granularity",
-    # note: Float8Tensor and Float8Linear are not public APIs
+    # note: Float8TrainingTensor and Float8Linear are not public APIs
 ]
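
The add_safe_globals registration above is what keeps weights_only deserialization working after the rename. A rough sketch of the user-visible effect (the checkpoint filename is hypothetical; assumes torch >= 2.5, per the TORCH_VERSION_AT_LEAST_2_5 guard):

import torch
import torchao.float8  # importing the package registers Float8TrainingTensor via add_safe_globals

# A state dict containing Float8TrainingTensor instances can then be loaded
# through the safer weights_only path.
state_dict = torch.load("float8_checkpoint.pt", weights_only=True)  # hypothetical file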
torchao/float8/distributed_utils.py (2 additions, 2 deletions)
@@ -8,15 +8,15 @@
 import torch.distributed._functional_collectives as funcol
 from torch.distributed._tensor import DTensor
 
-from torchao.float8.float8_tensor import Float8Tensor
+from torchao.float8.float8_training_tensor import Float8TrainingTensor
 
 
 def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
     """
     Check if the tensor is already casted to fp8, works if the local
     tensor is wrapped in DTensor.
     """
-    if isinstance(tensor, Float8Tensor):
+    if isinstance(tensor, Float8TrainingTensor):
         return True
     elif isinstance(tensor, DTensor):
         # TODO: shall we stick to public API and directly use tensor.to_local() here?
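
A usage sketch of the renamed check (assumptions: hp_tensor_to_float8_dynamic keeps the (tensor, float8_dtype, linear_mm_config) positional signature seen elsewhere in this diff, LinearMMConfig() is default-constructible, and the fp8 cast itself runs on CPU; none of this code is part of the PR):

import torch

from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_training_tensor import LinearMMConfig

x = torch.randn(16, 16)
assert not tensor_already_casted_to_fp8(x)  # plain high-precision tensor

# Dynamically scale and cast; the result is a Float8TrainingTensor, so the check flips.
x_fp8 = hp_tensor_to_float8_dynamic(x, torch.float8_e4m3fn, LinearMMConfig())
assert tensor_already_casted_to_fp8(x_fp8)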
torchao/float8/float8_linear.py (1 addition, 1 deletion)
@@ -17,7 +17,7 @@
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import (
+from torchao.float8.float8_training_tensor import (
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,