
[wip] add axiswise granularity to Float8Tensor
Summary:

This PR adds the axiswise scaling granularity to `Float8Tensor` and
ensures that basic ops like transpose and `torch._scaled_mm` work as
expected.

A future PR will add integration with `Float8Linear`.

Test Plan:

TODO

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 26, 2024
1 parent 3b786be commit e87f005
Showing 7 changed files with 151 additions and 16 deletions.
12 changes: 12 additions & 0 deletions float8_experimental/config.py
@@ -21,6 +21,18 @@ def short_str(self):
return "dyn"


class ScalingGranularity(enum.Enum):
"""
Defines the granularity of scaling strategies for casting to float8
"""

# A single scaling factor for the entire tensor
TENSORWISE = "tensorwise"
# Scaling factors computed along one axis of the tensor, reducing that
# axis to size 1.
AXISWISE = "axiswise"


@dataclass(frozen=True)
class CastConfig:
"""
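For orientation, a minimal sketch in plain PyTorch (not the library's code path) of what the two granularities above mean for the shape of the resulting amax, and therefore the scale:

```python
import torch

x = torch.randn(3, 5)

# TENSORWISE: a single amax (and hence a single scale) for the whole tensor.
amax_tensorwise = torch.max(torch.abs(x))                     # shape []

# AXISWISE: reduce along one axis with keepdim=True, so the result still
# broadcasts against the original tensor.
amax_rowwise = torch.amax(torch.abs(x), dim=1, keepdim=True)  # shape [3, 1]
amax_colwise = torch.amax(torch.abs(x), dim=0, keepdim=True)  # shape [1, 5]
```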
40 changes: 38 additions & 2 deletions float8_experimental/float8_ops.py
@@ -19,6 +19,15 @@
FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


def _assert_tensorwise_scale(aten_op, scale):
assert (
# TODO(future PR): figure out why tensorwise scaling can have
# both rank 0 and rank 1
len(scale.shape)
in (0, 1)
), f"{aten_op} with axiswise scaling is not supported yet"


def implements(aten_ops):
"""Register aten ops to the float8 op table"""

@@ -34,16 +43,15 @@ def decorator(func):
[
aten.view.default,
aten._unsafe_view.default,
aten.t.default,
aten.as_strided.default,
aten.clone.default,
aten.detach.default,
aten.slice.Tensor,
aten.transpose.int,
aten.fill_.Scalar,
]
)
def float8_desugar_op(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
return Float8Tensor(
new_data,
@@ -54,8 +62,27 @@ def float8_desugar_op(aten_op, args, kwargs=None):
)


@implements(
[
aten.t.default,
aten.transpose.int,
]
)
def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
return Float8Tensor(
new_data,
new_scale,
args[0]._orig_dtype,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
)


@implements([aten.split.Tensor])
def float8_split(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)

def make_float8(data):
@@ -101,6 +128,7 @@ def float8_cat(aten_op, args, kwargs=None):
assert (
chunk._gemm_input_role is gemm_input_role
), "Expecting all chunks to have the same gemm_input_role as a result of a split"
_assert_tensorwise_scale(aten_op, chunk._scale)
chunk_data.append(chunk._data.view(torch.uint8))

new_data = aten_op(chunk_data, *args[1:], **kwargs)
Expand All @@ -117,6 +145,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
"addmm" -> out
"hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)

def unwrap(x):
if isinstance(x, Float8Tensor):
@@ -229,6 +258,7 @@ def float8_addmm(aten_op, args, kwargs=None):

@implements([aten.is_same_size.default])
def float8_is_same_size(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
return args[0].shape == args[1].shape


@@ -238,6 +268,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
when the input is a Float8Tensor, presenting as a fp32
tensor.
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)
assert isinstance(args[0], Float8Tensor)
assert (
len(kwargs) == 1 and "dtype" in kwargs
@@ -265,6 +296,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
"""
override funcol with FP8 handling
"""
_assert_tensorwise_scale(aten_op, args[0]._scale)
fp8_input = args[0]
assert isinstance(
fp8_input, Float8Tensor
@@ -284,6 +316,7 @@ def allgather_fp8(aten_op, args, kwargs=None):

@implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
def wait_tensor_fp8(aten_op, args, kwargs=None):
_assert_tensorwise_scale(aten_op, args[0]._scale)
fp8_input = args[0]
assert isinstance(fp8_input, Float8Tensor)

@@ -304,6 +337,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
fp8_values = args[2]
assert isinstance(fp8_self, Float8Tensor)
assert isinstance(fp8_values, Float8Tensor)
_assert_tensorwise_scale(aten_op, fp8_self._scale)
assert fp8_self._scale == fp8_values._scale
assert fp8_self.dtype == fp8_values.dtype
assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -334,8 +368,10 @@ def copy_fp8(aten_op, args, kwargs=None):

if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
src_hp = src.to_original_precision()
_assert_tensorwise_scale(aten_op, src._scale)
return aten_op(self, src_hp, *args[2:], **kwargs)
elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
_assert_tensorwise_scale(aten_op, src._scale)
assert (
self._orig_dtype == src._orig_dtype
), "Expecting both Float8Tensors to be of the same dtype"
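The new `float8_desugar_data_and_scale` handler routes `t`/`transpose` through both `_data` and `_scale`. A small sketch with plain fp32 stand-ins (not real `Float8Tensor`s) of why the scale has to move with the data when scaling is axiswise:

```python
import torch

data = torch.randn(16, 32)                                # stand-in for _data
scale = torch.amax(torch.abs(data), dim=0, keepdim=True)  # axiswise, shape [1, 32]

# Dequantizing and then transposing must match transposing data and scale
# together; transposing the data alone would misalign the [1, 32] scale.
dequant_then_t = (data / scale).t()
t_then_dequant = data.t() / scale.t()                     # scale becomes [32, 1]

torch.testing.assert_close(dequant_then_t, t_then_dequant)
```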
8 changes: 8 additions & 0 deletions float8_experimental/float8_python_api.py
@@ -38,6 +38,14 @@ def addmm_float8_unwrapped(
"""
a_inverse_scale = a_scale.reciprocal()
b_inverse_scale = b_scale.reciprocal()

# TODO: should we change torch._scaled_mm?
# torch._scaled_mm expects rowwise scaled scales to be of rank 1, not rank
# 2. Translate to this format.
# TODO: audit if we need to make this more generic for various shapes.
a_inverse_scale = a_inverse_scale.squeeze()
b_inverse_scale = b_inverse_scale.squeeze()

if output_dtype == torch.float32 and bias is not None:
# Bias is not supported by _scaled_mm when output is fp32
output = torch._scaled_mm(
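Per the TODO comment in this hunk, the rowwise-scaled path expects rank-1 scales, while axiswise `Float8Tensor` scales are kept rank-2 so they broadcast against `_data`. A sketch of just the shape translation (no call into the kernel); the shapes here are illustrative:

```python
import torch

a_scale = torch.rand(16, 1) + 1e-3  # axiswise scale for a [16, 32] operand
b_scale = torch.rand(64, 1) + 1e-3  # axiswise scale for a [64, 32] operand

# Rank-2 broadcastable scales -> rank-1 scales for the rowwise-scaled gemm.
a_inverse_scale = a_scale.reciprocal().squeeze()  # [16, 1] -> [16]
b_inverse_scale = b_scale.reciprocal().squeeze()  # [64, 1] -> [64]

assert a_inverse_scale.shape == (16,)
assert b_inverse_scale.shape == (64,)
```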
14 changes: 13 additions & 1 deletion float8_experimental/float8_scaling_utils.py
@@ -12,6 +12,8 @@

import torch

from float8_experimental.config import ScalingGranularity

from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic(
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
) -> Float8Tensor:
"""
Given a high precision tensor `hp_tensor`,
@@ -49,10 +53,18 @@
reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
the 3 fwd/bwd gemms of linear
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
"""
if tensor_already_casted_to_fp8(hp_tensor):
return hp_tensor
scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
scale = tensor_to_scale(
hp_tensor,
float8_dtype,
reduce_amax,
scaling_granularity,
axiswise_dim,
)
return hp_tensor_and_scale_to_float8(
hp_tensor,
scale,
13 changes: 6 additions & 7 deletions float8_experimental/float8_tensor.py
@@ -250,7 +250,12 @@ class Float8Tensor(torch.Tensor):
* `_data`: the underlying e4m3 or e5m2 data
* `_scale`: the scale used to scale the original fp32 tensor. We multiply
by scale to go from fp32 range to fp8 range, and divide by scale to go
from fp8 range to fp32 range.
from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible
with `_data`. For example:
- if scaling is tensorwise, `_scale` is a scalar tensor
- if scaling is axiswise and _data.shape is [3, 5], `_scale` could have
shape [1, 5] or [5, 1]. The dim of the non-one entry defines the scaling
axis.
* `_orig_dtype`: the original dtype of the tensor used to create this
tensor.
* `_emulate`: if true using fp32 emulation for the matmuls, helpful
@@ -279,12 +284,6 @@ def __new__(
linear_mm_config: Optional[LinearMMConfig],
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
assert (
scale.numel() == 1
), "Scale should contain a single value, but got: {} elements".format(
scale.numel()
)

self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
30 changes: 25 additions & 5 deletions float8_experimental/float8_utils.py
@@ -4,12 +4,13 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterable, Literal, Tuple, Union
from typing import Iterable, Literal, Optional, Tuple, Union

import float8_experimental.config as config

import torch
import torch.distributed as dist
from float8_experimental.config import ScalingGranularity

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -100,8 +101,23 @@ def amax_history_to_scale_stack(


@torch.no_grad()
def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
amax = torch.max(torch.abs(x))
def tensor_to_amax(
x: torch.Tensor,
reduce_amax: bool = False,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
) -> torch.Tensor:
if scaling_granularity is ScalingGranularity.TENSORWISE:
amax = torch.max(torch.abs(x))
else:
assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
assert axiswise_dim is not None, "unsupported"

# convert from axiswise_dim (dim to keep) to
# dim as the input to the `torch.amax` function (tuple of dims to reduce)
dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)

amax = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True)

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
@@ -114,9 +130,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:

@torch.no_grad()
def tensor_to_scale(
x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
x: torch.Tensor,
float8_dtype: torch.dtype,
reduce_amax: bool = False,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
) -> torch.Tensor:
amax = tensor_to_amax(x, reduce_amax=reduce_amax)
amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim)
return amax_to_scale(amax, float8_dtype, x.dtype)


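A standalone sketch of the axiswise branch added above, assuming, as in the diff, that `axiswise_dim` names the dim to keep and every other dim is reduced:

```python
import torch

def axiswise_amax(x: torch.Tensor, axiswise_dim: int) -> torch.Tensor:
    # axiswise_dim is the dim to *keep*; reduce over all other dims.
    dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)
    return torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True)

x = torch.randn(16, 32)
assert axiswise_amax(x, axiswise_dim=0).shape == (16, 1)
assert axiswise_amax(x, axiswise_dim=1).shape == (1, 32)
```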
50 changes: 49 additions & 1 deletion test/test_base.py
@@ -16,14 +16,20 @@
import torch
import torch.nn as nn

from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
from float8_experimental.config import (
CastConfig,
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_python_api import addmm_float8_unwrapped
from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -143,6 +149,48 @@ def test_weights_only_load(self):
buffer.seek(0)
_ = torch.load(buffer, weights_only=True)

def test_axiswise_dynamic_cast(self):
a = torch.randn(16, 32, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
)
# print(a_fp8)
# print(a_fp8.to_original_precision())
# print(a_fp8.t())
b = a_fp8.t()
# TODO check numerical accuracy

def test_axiswise_gemm(self):
a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

linear_mm_config = LinearMMConfig()

a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
)
b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=0,
)
c = torch.mm(a_fp8, b_fp8.t())
print(c)
# TODO check numerical accuracy


class TestFloat8Linear:
def _test_linear_impl(
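Both new tests end with `# TODO check numerical accuracy`. One way such a check could look, sketched with a plain SQNR comparison against the bf16 reference matmul; `sqnr_db` and the 25 dB threshold are illustrative assumptions, not existing test utilities:

```python
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means closer to ref.
    err = ref - actual
    return 10 * torch.log10(ref.pow(2).mean() / err.pow(2).mean())

# Hypothetical usage inside test_axiswise_gemm, after computing `c`:
#   c_ref = torch.mm(a, b.t())
#   assert sqnr_db(c_ref.float(), c.float()) > 25.0  # threshold is an assumption
```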
