This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[4/x] clean up casting: ToFloat8ConstrFunc -> hp_tensor_and_scale_to_float8 #348

Closed
wants to merge 1 commit
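For context, this PR replaces direct calls to the autograd Function with a plain wrapper function; a representative before/after of a call site (a sketch, with argument names taken from the bench_padding.py call site in the diff below):

# before: invoke the autograd Function directly
a_fp8 = ToFloat8ConstrFunc.apply(A, scale_a, fp8_dtype, a_config, GemmInputRole.INPUT)
# after: call the new wrapper, which forwards to the now-private _ToFloat8ConstrFunc
a_fp8 = hp_tensor_and_scale_to_float8(A, scale_a, fp8_dtype, a_config, GemmInputRole.INPUT)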
6 changes: 3 additions & 3 deletions benchmarks/bench_padding.py
@@ -6,9 +6,9 @@
import torch
from float8_experimental.float8_tensor import (
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate
@@ -58,14 +58,14 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
a_config = LinearMMConfig(a_config, a_config, a_config)
b_config = LinearMMConfig(b_config, b_config, b_config)

a_fp8 = ToFloat8ConstrFunc.apply(
a_fp8 = hp_tensor_and_scale_to_float8(
A,
scale_a,
fp8_dtype,
a_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b_fp8 = hp_tensor_and_scale_to_float8(
B,
scale_b,
fp8_dtype,
10 changes: 5 additions & 5 deletions float8_experimental/float8_scaling_utils.py
@@ -15,10 +15,10 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
tensor_already_casted_to_fp8,
ToFloat8ConstrFunc,
)

from float8_experimental.float8_utils import (
@@ -39,7 +39,7 @@ def cast_to_float8_e4m3_dynamic(
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
return ToFloat8ConstrFunc.apply(
return hp_tensor_and_scale_to_float8(
inpt_tensor,
scale,
e4m3_dtype,
@@ -58,7 +58,7 @@ def cast_to_float8_delayed(
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
amax_buffer.fill_(tensor_to_amax(tensor))
return ToFloat8ConstrFunc.apply(
return hp_tensor_and_scale_to_float8(
tensor,
scale,
float8_dtype,
@@ -145,7 +145,7 @@ def backward(ctx, go):

fp8_amax_grad_output.fill_(tensor_to_amax(go))

res = ToFloat8ConstrFunc.apply(
res = hp_tensor_and_scale_to_float8(
go,
fp8_scale_grad_output,
e5m2_dtype,
@@ -177,7 +177,7 @@ def backward(ctx, gradY):
if tensor_already_casted_to_fp8(gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
fp8_tensor = ToFloat8ConstrFunc.apply(
fp8_tensor = hp_tensor_and_scale_to_float8(
gradY,
gradY_scale,
e5m2_dtype,
43 changes: 31 additions & 12 deletions float8_experimental/float8_tensor.py
@@ -129,7 +129,7 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:


@torch._dynamo.allow_in_graph
class ToFloat8ConstrFunc(torch.autograd.Function):
class _ToFloat8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion to fp8.
* forward: convert from high precision to float8
@@ -154,15 +154,6 @@ def forward(
with that composing with FakeTensor, so we special case here.

DTensor Invariant: DTensor must always be the outer most tensor subclass

Args:
tensor: the tensor to convert
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
linear_mm_config: Defines the configuration for the scaled_mm for
the 3 fwd/bwd gemms of linear
gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
the 3 fwd/bwd gemms of linear
"""
tensor_scaled = tensor * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
@@ -205,7 +196,7 @@ def backward(ctx, g):


@torch._dynamo.allow_in_graph
class FromFloat8ConstrFunc(torch.autograd.Function):
class _FromFloat8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion from fp8.
* forward: convert from float8 to high precision
@@ -221,6 +212,34 @@ def backward(ctx, g):
return g, None, None


def hp_tensor_and_scale_to_float8(
hp_tensor: torch.Tensor,
s: torch.Tensor,
float8_dtype=e4m3_dtype,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
"""
Given a high precision tensor `hp_tensor` and a precalculated scale `s`,
scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result.

Autograd-aware, the derivative is pass-through.
DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor).

Args:
hp_tensor: the tensor to convert
s: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
linear_mm_config: Defines the configuration for the scaled_mm for
the 3 fwd/bwd gemms of linear
gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
the 3 fwd/bwd gemms of linear
"""
return _ToFloat8ConstrFunc.apply(
hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role
)


class Float8Tensor(torch.Tensor):
"""
Note: this is **not** a public API and is only intended to be used
@@ -309,7 +328,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
)

def to_original_precision(self):
return FromFloat8ConstrFunc.apply(self)
return _FromFloat8ConstrFunc.apply(self)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
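As a quick illustration of the new public helper, here is a minimal usage sketch; it assumes the import paths and helpers (`tensor_to_scale`, `e4m3_dtype` from `float8_experimental.float8_utils`) shown elsewhere in this diff, and mirrors the round trip exercised in test_base.py:

import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    hp_tensor_and_scale_to_float8,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

# cast a high precision tensor to a Float8Tensor using a precalculated scale
x = torch.randn(16, 16, dtype=torch.bfloat16, requires_grad=True)
scale = tensor_to_scale(x, e4m3_dtype)
x_fp8 = hp_tensor_and_scale_to_float8(x, scale, e4m3_dtype, None, GemmInputRole.INPUT)

# the cast is autograd-aware with a pass-through derivative, so backward works
x_hp = x_fp8.to_original_precision()
x_hp.backward(torch.ones_like(x_hp))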
4 changes: 2 additions & 2 deletions float8_experimental/fsdp_utils.py
@@ -18,8 +18,8 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ToFloat8ConstrFunc,
)

from float8_experimental.float8_utils import e4m3_dtype, EPS
@@ -167,7 +167,7 @@ def __repr__(self):

def fsdp_pre_all_gather(self, mesh):
if self._precomputed_scale is not None:
float8_tensor = ToFloat8ConstrFunc.apply(
float8_tensor = hp_tensor_and_scale_to_float8(
self._tensor,
self._precomputed_scale,
torch.float8_e4m3fn,
6 changes: 3 additions & 3 deletions float8_experimental/inference.py
@@ -19,10 +19,10 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
tensor_already_casted_to_fp8,
ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

@@ -127,7 +127,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
self.weight, Float8Tensor
), "Weight has already been quantized, cannot quantize again."
scale = tensor_to_scale(self.weight, dtype)
quantized_weight = ToFloat8ConstrFunc.apply(
quantized_weight = hp_tensor_and_scale_to_float8(
self.weight,
scale,
dtype,
@@ -200,7 +200,7 @@ def cast_to_float8_e4m3_inference(
if static_quantization_scale is not None
else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
)
return ToFloat8ConstrFunc.apply(
return hp_tensor_and_scale_to_float8(
inpt_tensor,
scale,
e4m3_dtype,
42 changes: 21 additions & 21 deletions test/test_base.py
@@ -28,9 +28,9 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import (
compute_error,
@@ -66,7 +66,7 @@ def test_preserves_dtype(self) -> None:
for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
x1_hp = torch.randn(4, 4, dtype=hp_dtype)
x1_s = tensor_to_scale(x1_hp, lp_dtype)
x2_lp = ToFloat8ConstrFunc.apply(x1_hp, x1_s, lp_dtype)
x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
x3_hp = x2_lp.to_original_precision()
self.assertTrue(x3_hp.dtype == hp_dtype)

@@ -76,7 +76,7 @@ def test_differentiable_casts(self) -> None:
x = torch.randn(1).requires_grad_()
grad = torch.randn(1)
x_s = tensor_to_scale(x, f8_dtype)
x_f8 = ToFloat8ConstrFunc.apply(x, x_s, f8_dtype)
x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
x_f8_hp = x_f8.to_original_precision()
x_f8_hp.backward(grad)
# the gradient should be unchanged through both casts
@@ -85,7 +85,7 @@ def test_split_cat(self):
def test_split_cat(self):
a = torch.rand(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(a, e4m3_dtype)
fp8_a = ToFloat8ConstrFunc.apply(a, scale, e4m3_dtype)
fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)

splits = torch.split(fp8_a, 16)
catted = torch.cat(splits, dim=0)
@@ -94,14 +94,14 @@ def test_index_put(self):
def test_index_put(self):
a = torch.rand(16, dtype=torch.bfloat16)
scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)

index = torch.randint(0, 15, (16,), dtype=torch.long)

b = torch.rand(16, 16, dtype=torch.bfloat16)
scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
fp8_b = ToFloat8ConstrFunc.apply(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = ToFloat8ConstrFunc.apply(b, scale_b, torch.float8_e4m3fn)
fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

with self.assertRaises(AssertionError):
b[index] = fp8_a
@@ -112,7 +112,7 @@ def test_copy_(self):
def test_copy_(self):
a = torch.rand(16, dtype=torch.bfloat16)
scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)

b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
@@ -407,8 +407,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()

a_fp8 = ToFloat8ConstrFunc.apply(a, a_scale, input_dtype)
b_fp8 = ToFloat8ConstrFunc.apply(b, b_scale, input_dtype)
a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)

out_scaled_mm = addmm_float8_unwrapped(
a_fp8._data,
@@ -447,14 +447,14 @@ def test_different_configs_error(self):
ScaledMMConfig(True, False, False, False),
ScaledMMConfig(True, False, False, False),
)
a = ToFloat8ConstrFunc.apply(
a = hp_tensor_and_scale_to_float8(
x_fp32,
x_scale,
fp8_dtype,
linear_config_a,
GemmInputRole.INPUT,
)
b = ToFloat8ConstrFunc.apply(
b = hp_tensor_and_scale_to_float8(
x_fp32,
x_scale,
fp8_dtype,
@@ -486,10 +486,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()

a_fp8 = ToFloat8ConstrFunc.apply(
a_fp8 = hp_tensor_and_scale_to_float8(
a, a_scale, input_dtype, None, GemmInputRole.INPUT
)
b_fp8 = ToFloat8ConstrFunc.apply(
b_fp8 = hp_tensor_and_scale_to_float8(
b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
)

@@ -506,14 +506,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
scaled_mm_config, scaled_mm_config, scaled_mm_config
)

a_fp8 = ToFloat8ConstrFunc.apply(
a_fp8 = hp_tensor_and_scale_to_float8(
a,
a_scale,
input_dtype,
pad_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b_fp8 = hp_tensor_and_scale_to_float8(
b,
b_scale,
input_dtype,
@@ -529,14 +529,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
emulated_scaled_mm_config,
emulated_scaled_mm_config,
)
a_fp8 = ToFloat8ConstrFunc.apply(
a_fp8 = hp_tensor_and_scale_to_float8(
a,
a_scale,
input_dtype,
emulated_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b_fp8 = hp_tensor_and_scale_to_float8(
b,
b_scale,
input_dtype,
@@ -695,19 +695,19 @@ def test_fp8_tensor_statistics(self):

# Overflow caused by a too large scaling factor
s_overflow = torch.tensor(1e9)
fp8_overflow = ToFloat8ConstrFunc.apply(x1_hp, s_overflow, lp_dtype)
fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))

# Underflow caused by a too small scaling factor
s_underflow = torch.tensor(1e-9)
fp8_underflow = ToFloat8ConstrFunc.apply(x1_hp, s_underflow, lp_dtype)
fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))

# Both overflow and underflow
x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
fp8_over_underflow = ToFloat8ConstrFunc.apply(
fp8_over_underflow = hp_tensor_and_scale_to_float8(
x2_hp, torch.tensor(1.0), lp_dtype
)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)