This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Early return by checking whether the tensor is already cast to fp8 #233

Closed · wants to merge 8 commits
5 changes: 0 additions & 5 deletions float8_experimental/config.py
@@ -14,8 +14,3 @@
# this doesn't work with autocast + torch.compile + FSDP. Enabling this
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True

# If True, dynamic linear uses hooks for activation casting
# TODO(before land): add test coverage for both cases
# dynamic_use_activation_hooks = True
# dynamic_use_activation_hooks = False
57 changes: 16 additions & 41 deletions float8_experimental/float8_dynamic_linear.py
@@ -8,7 +8,11 @@
"""
import torch

from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
from float8_experimental.float8_tensor import (
Float8Tensor,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
from float8_experimental.float8_utils import tensor_to_scale


@@ -30,63 +34,43 @@ def forward(

@staticmethod
def backward(ctx, gradY):
if tensor_already_casted_to_fp8(gradY):
@wanchaol (Contributor, Author) commented on Mar 12, 2024:

For cast_to_float8_e5m2_bw, unfortunately I can't do a forward-only check and have to check the backward gradients to see whether they are already cast, since the forward only has y but not grad_y.
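A minimal sketch (not part of the diff) of the pattern this comment describes, mirroring NoopFwToFloat8E5M2Bw: the custom autograd Function is a no-op in forward, so the "already cast" check can only happen in backward, where grad_y first becomes visible. The forward body below is an assumption for illustration; only the backward appears in the hunk.

```python
import torch

from float8_experimental.float8_tensor import (
    tensor_already_casted_to_fp8,
    to_fp8_no_autograd,
)
from float8_experimental.float8_utils import tensor_to_scale


class NoopFwCastBwSketch(torch.autograd.Function):
    """No-op in forward; casts the incoming gradient to float8_e5m2 in backward."""

    @staticmethod
    def forward(ctx, y, emulate: bool):
        # Only y is visible here; grad_y does not exist yet, so the
        # early-return check cannot be done in forward.
        ctx.emulate = emulate
        return y

    @staticmethod
    def backward(ctx, gradY):
        if tensor_already_casted_to_fp8(gradY):
            # gradY was already cast upstream (e.g. it is a Float8Tensor),
            # so return it unchanged.
            return gradY, None
        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
        fp8_tensor = to_fp8_no_autograd(
            gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
        )
        return fp8_tensor, None
```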

# check to early return if already casted to float8
return gradY, None
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
)
return fp8_tensor, None


def cast_x_to_float8_e4m3fn_pre_hook(module, args):
"""
Hook to cast the incoming activation to `torch.float8_e4m3fn`
"""
return module.cast_to_float8_e4m3fn(args[0])


def cast_grad_to_float8_e5m2_backward_forward_hook(module, input, output):
"""This is a forward hook that sends the output of the model through
a no-op in the forward but a cast to float8_e5m2 in the backward.

Args:
module (nn.Module): the module to cast the output of
input (Tensor): the input to the module forward call
output (Tensor): the output of the module forward
"""
return module.cast_to_float8_e5m2_bw(output)


class Float8DynamicLinear(torch.nn.Linear):
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
conversion to fp8 of the input and weight tensors.
"""

def __init__(self, use_activation_hooks: bool, **super_kwargs):
"""
Args:
use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
"""
def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

self.use_activation_hooks = use_activation_hooks

def forward(self, x):
# cast x to float8_e4m3fn if not using activation hooks
x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)
x_fp8 = self.cast_to_float8_e4m3fn(x)

# cast w to float8_e4m3fn
w_fp8 = self.cast_to_float8_e4m3fn(self.weight)

y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

# Cast gradY to float8_e5m2 during backward if not using activation hooks
if not self.use_activation_hooks:
y = self.cast_to_float8_e5m2_bw(y)
y = self.cast_to_float8_e5m2_bw(y)

return y

def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
# check to early return if already casted to float8
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -96,31 +80,22 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)

@classmethod
def from_float(
cls, mod, emulate: bool = False, use_activation_hooks: bool = False
) -> "Float8DynamicLinear":
def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
"""
with torch.device("meta"):
super_kwargs = {
"in_features": mod.in_features,
"out_features": mod.out_features,
"bias": False,
}
new_mod = cls(use_activation_hooks, **super_kwargs)
new_mod = cls(**super_kwargs)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
if new_mod.use_activation_hooks:
# install the hooks
new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
new_mod.register_forward_hook(
cast_grad_to_float8_e5m2_backward_forward_hook
)
return new_mod
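A hedged usage sketch of the simplified from_float API after this change (no use_activation_hooks argument); the layer sizes and CUDA device are illustrative and mirror the tests further down.

```python
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# Convert a regular nn.Linear into a dynamically scaled fp8 linear.
# emulate=True keeps the fp8 matmul logic in float32, useful on
# hardware without native fp8 support.
m_ref = nn.Linear(16, 32, bias=True, device="cuda")
m_fp8 = Float8DynamicLinear.from_float(m_ref, emulate=True)

x = torch.randn(4, 16, device="cuda")
y = m_fp8(x)  # input and weight are cast to float8_e4m3fn on the fly
y.sum().backward()  # gradY is cast to float8_e5m2 via NoopFwToFloat8E5M2Bw
```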
5 changes: 2 additions & 3 deletions float8_experimental/float8_linear.py
@@ -304,16 +304,15 @@ def forward(self, x):
return y

@classmethod
def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
def from_float(cls, mod, emulate: bool = False):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic
cast_activation (bool): whether to use activation hooks instead of inlining the casting logic
A reviewer commented:

nit: This looks like it should be removed.

A contributor replied:

I think this was removed in the next PR.
"""
assert not use_activation_hooks, "use_activation_hooks is not supported yet!"
# TODO Follow up! This is a great idea but we need the mixin base to create real
# Tensors and the Linear base to create empty params
# with torch.device("meta"):
15 changes: 2 additions & 13 deletions float8_experimental/float8_linear_utils.py
@@ -33,27 +33,22 @@ def get_float8_linear(
linear_type: LinearType,
linear_ref: torch.nn.Linear,
emulate: bool = False,
use_activation_hooks: bool = False,
):
"""Returns a Float8Linear module of the given type, initialized from linear_ref.
Args:
linear_type: The type of Float8Linear to return.
linear_ref: The linear module to initialize from.
emulate: Whether to emulate the fp8 matmul logic in float32.
use_activation_hooks: Whether to use activation hooks for dynamic linear.
"""
LINEAR_TYPE_MAP = {
LinearType.DELAYED: Float8Linear,
LinearType.DYNAMIC: Float8DynamicLinear,
}
if linear_type not in LINEAR_TYPE_MAP:
raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}")
if use_activation_hooks and linear_type != LinearType.DYNAMIC:
raise ValueError("use_activation_hooks is only supported for dynamic linear")
return LINEAR_TYPE_MAP[linear_type].from_float(
copy.deepcopy(linear_ref),
emulate=emulate,
use_activation_hooks=use_activation_hooks,
)


@@ -104,7 +99,6 @@ def swap_linear_with_float8_linear(
*,
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
use_activation_hooks: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> nn.Module:
"""
@@ -117,7 +111,6 @@ skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
Linear submodules of these skipped modules will also be skipped.
emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
that pass the filter function will be swapped.
"""
@@ -129,9 +122,7 @@
raise AssertionError(
f"Does not support a root nn.Linear with children: {module}"
)
return module_cls.from_float(
module, emulate=emulate, use_activation_hooks=use_activation_hooks
)
return module_cls.from_float(module, emulate=emulate)

# Mark all modules to skip as visited
root_module = module
@@ -155,9 +146,7 @@ def post_order_traversal(
assert (
parent_module is not None
), f"Linear root module should return early: {module}"
float8linear_module = module_cls.from_float(
module, emulate=emulate, use_activation_hooks=use_activation_hooks
)
float8linear_module = module_cls.from_float(module, emulate=emulate)
setattr(parent_module, module_name, float8linear_module)

post_order_traversal(root_module, "", None)
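A hedged usage sketch of module swapping after this change. The hunk only shows the keyword-only part of the signature, so the first two positional arguments (the root module and the Float8 linear class) are an assumption inferred from the module_cls.from_float calls above.

```python
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(
    nn.Linear(32, 64, device="cuda"),
    nn.ReLU(),
    nn.Linear(64, 16, device="cuda"),
)

# Replace every nn.Linear in the model with Float8DynamicLinear in place;
# there is no use_activation_hooks flag to pass anymore.
model = swap_linear_with_float8_linear(model, Float8DynamicLinear, emulate=True)
```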
18 changes: 17 additions & 1 deletion float8_experimental/float8_tensor.py
@@ -7,13 +7,29 @@

import torch

from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
import torch.distributed._functional_collectives as funcol

from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
from torch.distributed._tensor import DTensor

aten = torch.ops.aten


def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
"""
Check if the tensor is already casted to fp8
"""
if isinstance(tensor, Float8Tensor):
return True
elif isinstance(tensor, DTensor):
# TODO: shall we stick to public API and directly use tensor.to_local() here?
A contributor commented:

I wonder if, in general, subclasses composing with other subclasses should have a generic way to determine the nested types.

@wanchaol (Contributor, Author) replied:

Yeah, I'll also need to think about how this should behave. This is really about how to get the subclass hierarchy from type info alone; simply trying type.mro or inspect.getmro does not seem to work for me. Maybe I need to check with @ezyang @bdhirsh to see if there are better suggestions.
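A rough illustration (not part of this PR) of the manual unwrapping the thread is discussing: peel known wrapper subclasses recursively until a terminal tensor type is reached. The attribute names for DTensor and AsyncCollectiveTensor follow the diff below; a truly generic mechanism is exactly what the comment says is still open.

```python
import torch
import torch.distributed._functional_collectives as funcol

from float8_experimental.float8_tensor import Float8Tensor
from torch.distributed._tensor import DTensor


def innermost_tensor(t: torch.Tensor) -> torch.Tensor:
    """Recursively unwrap known wrapper subclasses to reach the innermost tensor."""
    if isinstance(t, DTensor):
        # Private attribute, as used in the diff; to_local() is the public alternative.
        return innermost_tensor(t._local_tensor)
    if isinstance(t, funcol.AsyncCollectiveTensor):
        return innermost_tensor(t.elem)
    return t


def already_fp8(t: torch.Tensor) -> bool:
    # Equivalent spelling of tensor_already_casted_to_fp8: inspect the innermost type.
    return isinstance(innermost_tensor(t), Float8Tensor)
```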

return tensor_already_casted_to_fp8(tensor._local_tensor)
elif isinstance(tensor, funcol.AsyncCollectiveTensor):
return tensor_already_casted_to_fp8(tensor.elem)

return False


def to_fp8_no_autograd(
x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
) -> "Float8Tensor":
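A short usage sketch of the new helper on a plain tensor and on a Float8Tensor. The Float8Tensor.to_float8 argument order follows cast_to_float8_e4m3fn in the diff above; the shapes and emulate flag are illustrative.

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor, tensor_already_casted_to_fp8
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(16, 16, device="cuda")
assert not tensor_already_casted_to_fp8(x)  # plain tensor: still needs casting

scale = tensor_to_scale(x, torch.float8_e4m3fn)
x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn, emulate=True)
assert tensor_already_casted_to_fp8(x_fp8)  # Float8Tensor: the early return applies
```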
24 changes: 0 additions & 24 deletions test/conftest.py

This file was deleted.

20 changes: 5 additions & 15 deletions test/test_base.py
@@ -60,9 +60,8 @@ def _test_linear_impl(
m_ref,
linear_type: LinearType,
emulate: bool,
use_activation_hooks: bool = False,
):
m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
for _ in range(2):
if linear_requires_sync(linear_type):
sync_float8_amax_and_scale_history(m_fp8)
@@ -123,15 +122,12 @@ def _test_linear_impl(
@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize("use_activation_hooks", [True, False])
@pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear_nobias(
self,
x_shape,
linear_type: LinearType,
emulate: bool,
use_activation_hooks: bool,
):
if not emulate:
if not torch.cuda.is_available():
@@ -145,24 +141,21 @@ def test_linear_nobias(

x = torch.randn(*x_shape, device="cuda")
m_ref = nn.Linear(16, 32, bias=False, device="cuda")
self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)
self._test_linear_impl(x, m_ref, linear_type, emulate)

@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
@pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear_bias(
self,
x_shape,
linear_type: LinearType,
emulate: bool,
linear_dtype: torch.dtype,
use_activation_hooks: bool,
):
if not emulate:
if not torch.cuda.is_available():
@@ -176,22 +169,19 @@ def test_linear_bias(

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)
self._test_linear_impl(x, m_ref, linear_type, emulate)

@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
@pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_autocast_outputs(
self,
linear_type: LinearType,
emulate: bool,
linear_dtype: torch.dtype,
use_activation_hooks: bool,
):
if not emulate:
if not torch.cuda.is_available():
@@ -204,7 +194,7 @@ def test_autocast_outputs(
pytest.skip()

m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
m = get_float8_linear(linear_type, m_ref, emulate)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
@@ -242,7 +232,7 @@ def test_type_cast(
)

m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = get_float8_linear(linear_type, m, emulate, False)
m = get_float8_linear(linear_type, m, emulate)

# Cast the module to dtype
m = m.to(dtype=linear_dtype)