Update on "[wip] add axiswise granularity to Float8Tensor"
Summary:

This PR adds the axiswise scaling granularity to `Float8Tensor` and
ensures that basic ops like transpose and `torch._scaled_mm` work as
expected.

A future PR will add integration with `Float8Linear`.

Test Plan:

TODO

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 26, 2024
1 parent e87f005 commit c4c9ae8
Showing 4 changed files with 62 additions and 25 deletions.
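For orientation before the diff, here is a minimal sketch of the axiswise dynamic cast this PR enables, adapted from the new test_axiswise_dynamic_cast in test/test_base.py below. The import paths are assumptions (the diff only shows the call sites), and exact signatures may differ.

import torch

# Assumed import paths; the diff below does not show them.
from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic
from float8_experimental.float8_tensor import LinearMMConfig, ScalingGranularity
from float8_experimental.float8_utils import compute_error, e4m3_dtype

x = torch.randn(16, 32, dtype=torch.bfloat16)

# One scale per slice along axiswise_dim, instead of a single tensorwise scale.
x_fp8 = hp_tensor_to_float8_dynamic(
    x,
    e4m3_dtype,
    LinearMMConfig(),
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=0,
)

# Round-tripping back to the original precision should stay reasonably accurate;
# the tests below use an SQNR threshold of 25 for this.
x_dq = x_fp8.to_original_precision()
assert compute_error(x, x_dq) >= 25.0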
float8_experimental/float8_ops.py: 14 additions & 1 deletion
@@ -41,7 +41,7 @@ def decorator(func):

@implements(
[
aten.view.default,
# aten.view.default,
aten._unsafe_view.default,
aten.as_strided.default,
aten.clone.default,
@@ -79,6 +79,19 @@ def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
args[0]._gemm_input_role,
)

@implements([aten.view.default])
def float8_view(aten_op, args, kwargs=None):
if len(args[0]._scale.shape) < 2:
# tensorwise scaling
return float8_desugar_op(aten_op, *args, **kwargs)
print('args', args)
print('kwargs', kwargs)
tensor, new_shape = args[0], args[1]

# for now, only support reshaping to [-1, *dims] or [*dims, -1]
if len(new_shape) >= 2 and (new_shape[0] == -1 or new_shape[-1] == -1):
return float8_desugar_data_and_scale(aten_op, *args, **kwargs)
raise AssertionError(f"{aten_op} with axiswise scaling and shape {new_shape} is not supported yet.")

@implements([aten.split.Tensor])
def float8_split(aten_op, args, kwargs=None):
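To make the intent of the new float8_view handler concrete, here is a short sketch of the reshape behavior it is aiming for, using the same assumed helpers as the sketch above. The corresponding test (test_axiswise_reshape below) is still marked TODO(next), so this shows intended behavior rather than something the commit already guarantees.

a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
a_fp8 = hp_tensor_to_float8_dynamic(
    a,
    e4m3_dtype,
    LinearMMConfig(),
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=0,
)
# a_fp8._data.shape == (3, 5, 7), a_fp8._scale.shape == (1, 5, 7)

# Flattening only the trailing (or only the leading) dims keeps the scale
# consistent with the data, so the handler forwards it to the desugar path:
a_fp8_r = a_fp8.reshape(3, -1)   # intended: data (3, 35), scale (1, 35)

# Any other target shape is rejected for now:
# a_fp8.reshape(5, 21)  ->  AssertionError from float8_view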
float8_experimental/float8_python_api.py: 0 additions & 7 deletions
@@ -39,13 +39,6 @@ def addmm_float8_unwrapped(
a_inverse_scale = a_scale.reciprocal()
b_inverse_scale = b_scale.reciprocal()

# TODO: should we change torch._scaled_mm?
# torch._scaled_mm expects rowwise scaled scales to be of rank 1, not rank
# 2. Translate to this format.
# TODO: audit if we need to make this more generic for various shapes.
a_inverse_scale = a_inverse_scale.squeeze()
b_inverse_scale = b_inverse_scale.squeeze()

if output_dtype == torch.float32 and bias is not None:
# Bias is not supported by _scaled_mm when output is fp32
output = torch._scaled_mm(
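The lines deleted above used to squeeze the inverse scales to rank 1, since (per the removed comment) torch._scaled_mm expected rank-1 rowwise scales; axiswise scales are produced with keepdim=True (see float8_utils.py below), so they start out at rank 2. A tiny shape sketch of what the removed squeeze did (plain torch only; not a claim about what torch._scaled_mm accepts after this change):

import torch

a_scale = torch.rand(16, 1) + 0.5         # per-row scale of a (16, 32) operand, keepdim=True
a_inverse_scale = a_scale.reciprocal()    # still rank 2, shape (16, 1)

# The removed code collapsed this to rank 1 before the torch._scaled_mm call:
assert a_inverse_scale.squeeze().shape == (16,)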
float8_experimental/float8_utils.py: 2 additions & 2 deletions
@@ -115,9 +115,9 @@ def tensor_to_amax(

# convert from axiswise_dim (dim to keep) to
# dim as the input to the `torch.amax` function (tuple of dims to reduce)
dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)
# dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)

amax = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True)
amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
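A small standalone sketch of what the changed reduction above produces: axiswise_dim is now handed straight to torch.amax as the dim to reduce, and keepdim=True keeps the result broadcastable against the input (plain torch, no Float8Tensor machinery):

import torch

x = torch.randn(16, 32)

# Reduce the absolute values over dim 1 only, keeping it as a singleton dim,
# so there is one amax (and hence one scale) per row of x.
amax = torch.amax(torch.abs(x), dim=1, keepdim=True)
assert amax.shape == (16, 1)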
test/test_base.py: 46 additions & 15 deletions
@@ -63,7 +63,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
return True


class TestFloat8Tensor(unittest.TestCase):
class TestFloat8Tensor:
def test_preserves_dtype(self) -> None:
# hp means high precision, lp means low precision
hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -73,7 +73,7 @@ def test_preserves_dtype(self) -> None:
x1_s = tensor_to_scale(x1_hp, lp_dtype)
x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
x3_hp = x2_lp.to_original_precision()
self.assertTrue(x3_hp.dtype == hp_dtype)
assert x3_hp.dtype == hp_dtype

def test_differentiable_casts(self) -> None:
lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -108,7 +108,7 @@ def test_index_put(self):
fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
b[index] = fp8_a
fp8_b[index] = a
fp8_b_bad[index] = fp8_a
@@ -122,7 +122,7 @@ def test_copy_(self):
b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
torch.testing.assert_close(b, fp8_a.to_original_precision())
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
fp8_a.copy_(b) # Should fail

fp8_b = Float8Tensor(
@@ -149,21 +149,49 @@ def test_weights_only_load(self):
buffer.seek(0)
_ = torch.load(buffer, weights_only=True)

def test_axiswise_dynamic_cast(self):
a = torch.randn(16, 32, dtype=torch.bfloat16)
@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("dim_name", ["first", "last"])
def test_axiswise_dynamic_cast(self, shape, dim_name):
a = torch.randn(*shape, dtype=torch.bfloat16)

if dim_name == "first":
dim = 0
elif dim_name == "last":
dim = len(a.shape) - 1

linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=dim,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

# TODO(next) make this work
def test_axiswise_reshape(self):
a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda")
linear_mm_config = LinearMMConfig()

a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
)
# print(a_fp8)
# print(a_fp8.to_original_precision())
# print(a_fp8.t())
b = a_fp8.t()
# TODO check numerical accuracy
# a_fp8._data.shape is (3, 5, 7)
# a_fp8._scale.shape is (1, 5, 7)
print(a_fp8._scale.shape)

# reshape to (3, 5 * 7)
# a_fp8._scale.shape should be (1, 5 * 7)
a_fp8_r = a_fp8.reshape(3, -1)
print(a_fp8_r._scale.shape)


def test_axiswise_gemm(self):
a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
@@ -177,18 +177,21 @@ def test_axiswise_gemm(self):
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
axiswise_dim=1,
)
b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
axiswise_dim=1,
)
c = torch.mm(a_fp8, b_fp8.t())
print(c)
c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
print(c_fp8_compute)
c_ref = torch.mm(a, b.t())
sqnr = compute_error(c_ref, c_fp8_compute)
print('sqnr', sqnr)
# TODO check numerical accuracy


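The gemm test above ends with a TODO for the accuracy check, and it already computes sqnr = compute_error(c_ref, c_fp8_compute); a minimal sketch of the missing assertion, reusing the threshold from test_axiswise_dynamic_cast (an assumption, not something this commit asserts):

# Hypothetical completion of the TODO in test_axiswise_gemm; the 25.0 threshold
# is carried over from test_axiswise_dynamic_cast.
assert sqnr >= 25.0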