diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 9df065bc..f0ba914f 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -14,8 +14,3 @@
 # this doesn't work with autocast + torch.compile + FSDP. Enabling this
 # option is useful for safety, but not strictly necessary.
 enable_pre_and_post_forward = True
-
-# If True, dynamic linear uses hooks for activation casting
-# TODO(before land): add test coverage for both cases
-# dynamic_use_activation_hooks = True
-# dynamic_use_activation_hooks = False
diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 14f9fbc7..4a5ac2e1 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -8,7 +8,11 @@
 """
 
 import torch
 
-from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    tensor_already_casted_to_fp8,
+    to_fp8_no_autograd,
+)
 from float8_experimental.float8_utils import tensor_to_scale
 
@@ -30,6 +34,9 @@ def forward(
 
     @staticmethod
    def backward(ctx, gradY):
+        if tensor_already_casted_to_fp8(gradY):
+            # early return if gradY is already cast to float8
+            return gradY, None
         gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
         fp8_tensor = to_fp8_no_autograd(
             gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
@@ -37,43 +44,18 @@ def backward(ctx, gradY):
         return fp8_tensor, None
 
 
-def cast_x_to_float8_e4m3fn_pre_hook(module, args):
-    """
-    Hook to cast the incoming activation to `torch.float8_e4m3fn`
-    """
-    return module.cast_to_float8_e4m3fn(args[0])
-
-
-def cast_grad_to_float8_e5m2_backward_forward_hook(module, input, output):
-    """This is a forward hook that sends the output of the model through
-    a no-op in the forward but a cast to float8_e5m2 in the backward.
-
-    Args:
-        module (nn.Module): the module to cast the output of
-        input (Tensor): the input to the module forward call
-        output (Tensor): the output of the module forward
-    """
-    return module.cast_to_float8_e5m2_bw(output)
-
-
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
     conversion to fp8 of the input and weight tensors.
     """
 
-    def __init__(self, use_activation_hooks: bool, **super_kwargs):
-        """
-        Args:
-            use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
-        """
+    def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
-        self.use_activation_hooks = use_activation_hooks
-
     def forward(self, x):
         # cast x to float8_e4m3fn if not using activation hooks
-        x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)
+        x_fp8 = self.cast_to_float8_e4m3fn(x)
 
         # cast w to float8_e4m3fn
         w_fp8 = self.cast_to_float8_e4m3fn(self.weight)
@@ -81,12 +63,14 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward if not using activation hooks
-        if not self.use_activation_hooks:
-            y = self.cast_to_float8_e5m2_bw(y)
+        y = self.cast_to_float8_e5m2_bw(y)
 
         return y
 
     def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
+        if tensor_already_casted_to_fp8(inpt_tensor):
+            # early return if the input is already cast to float8
+            return inpt_tensor
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -96,16 +80,13 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
         return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)
 
     @classmethod
-    def from_float(
-        cls, mod, emulate: bool = False, use_activation_hooks: bool = False
-    ) -> "Float8DynamicLinear":
+    def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
-            use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -113,14 +94,8 @@ def from_float(
                 "out_features": mod.out_features,
                 "bias": False,
             }
-            new_mod = cls(use_activation_hooks, **super_kwargs)
+            new_mod = cls(**super_kwargs)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
-        if new_mod.use_activation_hooks:
-            # install the hooks
-            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
-            new_mod.register_forward_hook(
-                cast_grad_to_float8_e5m2_backward_forward_hook
-            )
         return new_mod
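
The hook-based casting path for Float8DynamicLinear is gone: the input and weight are now always cast inline in forward(), and tensor_already_casted_to_fp8 turns the forward and backward casts into no-ops when the incoming tensor is already a Float8Tensor (or a DTensor / AsyncCollectiveTensor wrapping one). A minimal sketch of the resulting call flow, not taken from the patch itself, assuming float8_experimental is importable and that emulate=True is acceptable for illustration:

import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# from_float no longer accepts use_activation_hooks; emulate=True keeps the
# matmul in float32, so this sketch does not need fp8-capable hardware.
m_ref = nn.Linear(16, 32, bias=False)
m_fp8 = Float8DynamicLinear.from_float(m_ref, emulate=True)

x = torch.randn(4, 16)
y = m_fp8(x)  # x and weight are cast to float8_e4m3fn inline in forward()
# gradY is cast to float8_e5m2 in backward via NoopFwToFloat8E5M2Bw; if gradY
# is already a Float8Tensor, the new early return passes it through unchanged.
y.sum().backward()
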
""" - def __init__(self, use_activation_hooks: bool, **super_kwargs): - """ - Args: - use_activation_hooks (bool): whether to use activation hooks for casting to and from float8 - """ + def __init__(self, **super_kwargs): super().__init__(**super_kwargs) - self.use_activation_hooks = use_activation_hooks - def forward(self, x): # cast x to float8_e4m3fn if not using activation hooks - x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x) + x_fp8 = self.cast_to_float8_e4m3fn(x) # cast w to float8_e4m3fn w_fp8 = self.cast_to_float8_e4m3fn(self.weight) @@ -81,12 +63,14 @@ def forward(self, x): y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) # Cast gradY to float8_e5m2 during backward if not using activation hooks - if not self.use_activation_hooks: - y = self.cast_to_float8_e5m2_bw(y) + y = self.cast_to_float8_e5m2_bw(y) return y def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor: + if tensor_already_casted_to_fp8(inpt_tensor): + # check to early return if already casted to float8 + return inpt_tensor scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn) return Float8Tensor.to_float8( inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate @@ -96,16 +80,13 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor: return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate) @classmethod - def from_float( - cls, mod, emulate: bool = False, use_activation_hooks: bool = False - ) -> "Float8DynamicLinear": + def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear": """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 - use_activation_hooks (bool): whether to use activation hooks for casting to and from float8 """ with torch.device("meta"): super_kwargs = { @@ -113,14 +94,8 @@ def from_float( "out_features": mod.out_features, "bias": False, } - new_mod = cls(use_activation_hooks, **super_kwargs) + new_mod = cls(**super_kwargs) new_mod.weight = mod.weight new_mod.bias = mod.bias new_mod.emulate = emulate - if new_mod.use_activation_hooks: - # install the hooks - new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook) - new_mod.register_forward_hook( - cast_grad_to_float8_e5m2_backward_forward_hook - ) return new_mod diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 64a55652..4f6d277d 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -304,16 +304,15 @@ def forward(self, x): return y @classmethod - def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False): + def from_float(cls, mod, emulate: bool = False): """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 - use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic + cast_activation (bool): whether to use activation hooks instead of inlining the casting logic """ - assert not use_activation_hooks, "use_activation_hooks is not supported yet!" # TODO Follow up! 
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index abd6cdbe..8568b516 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -33,14 +33,12 @@ def get_float8_linear(
     linear_type: LinearType,
     linear_ref: torch.nn.Linear,
     emulate: bool = False,
-    use_activation_hooks: bool = False,
 ):
     """Returns a Float8Linear module of the given type, initialized from linear_ref.
     Args:
         linear_type: The type of Float8Linear to return.
         linear_ref: The linear module to initialize from.
         emulate: Whether to emulate the fp8 matmul logic in float32.
-        use_activation_hooks: Whether to use activation hooks for dynamic linear.
     """
     LINEAR_TYPE_MAP = {
         LinearType.DELAYED: Float8Linear,
@@ -48,12 +46,9 @@ def get_float8_linear(
     }
     if linear_type not in LINEAR_TYPE_MAP:
         raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}")
-    if use_activation_hooks and linear_type != LinearType.DYNAMIC:
-        raise ValueError("use_activation_hooks is only supported for dynamic linear")
     return LINEAR_TYPE_MAP[linear_type].from_float(
         copy.deepcopy(linear_ref),
         emulate=emulate,
-        use_activation_hooks=use_activation_hooks,
     )
 
 
@@ -104,7 +99,6 @@ def swap_linear_with_float8_linear(
     *,
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
-    use_activation_hooks: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
 ) -> nn.Module:
     """
@@ -117,7 +111,6 @@ def swap_linear_with_float8_linear(
         skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
             Linear submodules of these skipped modules will also be skipped.
         emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
-        use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
         linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear
             layers that pass the filter function will be swapped.
     """
@@ -129,9 +122,7 @@ def swap_linear_with_float8_linear(
             raise AssertionError(
                 f"Does not support a root nn.Linear with children: {module}"
             )
-        return module_cls.from_float(
-            module, emulate=emulate, use_activation_hooks=use_activation_hooks
-        )
+        return module_cls.from_float(module, emulate=emulate)
 
     # Mark all modules to skip as visited
     root_module = module
@@ -155,9 +146,7 @@ def post_order_traversal(
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            float8linear_module = module_cls.from_float(
-                module, emulate=emulate, use_activation_hooks=use_activation_hooks
-            )
+            float8linear_module = module_cls.from_float(module, emulate=emulate)
             setattr(parent_module, module_name, float8linear_module)
 
     post_order_traversal(root_module, "", None)
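
With use_activation_hooks removed, delayed and dynamic Float8 linears are built through the same signature. A hedged usage sketch, not part of the patch; import locations follow the tests below, and the positional parameter order of swap_linear_with_float8_linear (module first, then the Float8 linear class) is assumed from the surrounding code rather than shown in these hunks:

import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import (
    LinearType,
    get_float8_linear,
    swap_linear_with_float8_linear,
)

# Single layer: choose delayed or dynamic scaling via LinearType.
m_fp8 = get_float8_linear(LinearType.DYNAMIC, nn.Linear(16, 32), emulate=True)

# Whole model: replace every nn.Linear with Float8DynamicLinear.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
model = swap_linear_with_float8_linear(model, Float8DynamicLinear, emulate=True)
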
""" @@ -129,9 +122,7 @@ def swap_linear_with_float8_linear( raise AssertionError( f"Does not support a root nn.Linear with children: {module}" ) - return module_cls.from_float( - module, emulate=emulate, use_activation_hooks=use_activation_hooks - ) + return module_cls.from_float(module, emulate=emulate) # Mark all modules to skip as visited root_module = module @@ -155,9 +146,7 @@ def post_order_traversal( assert ( parent_module is not None ), f"Linear root module should return early: {module}" - float8linear_module = module_cls.from_float( - module, emulate=emulate, use_activation_hooks=use_activation_hooks - ) + float8linear_module = module_cls.from_float(module, emulate=emulate) setattr(parent_module, module_name, float8linear_module) post_order_traversal(root_module, "", None) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 3647e185..365e6093 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -7,13 +7,29 @@ import torch -from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated +import torch.distributed._functional_collectives as funcol +from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated from torch.distributed._tensor import DTensor aten = torch.ops.aten +def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: + """ + Check if the tensor is already casted to fp8 + """ + if isinstance(tensor, Float8Tensor): + return True + elif isinstance(tensor, DTensor): + # TODO: shall we stick to public API and directly use tensor.to_local() here? + return tensor_already_casted_to_fp8(tensor._local_tensor) + elif isinstance(tensor, funcol.AsyncCollectiveTensor): + return tensor_already_casted_to_fp8(tensor.elem) + + return False + + def to_fp8_no_autograd( x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool ) -> "Float8Tensor": diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index 30e42fd4..00000000 --- a/test/conftest.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - - -@pytest.fixture -def x_fail_activation_hooks(request): - use_activation_hooks = request.getfixturevalue("use_activation_hooks") - if use_activation_hooks: - request.node.add_marker( - pytest.mark.xfail( - reason="use_activation_hooks is not supported for AOT", strict=True - ) - ) - - -@pytest.fixture -def x_fail_activation_hooks_with_delayed(request): - linear_type = request.getfixturevalue("linear_type") - use_activation_hooks = request.getfixturevalue("use_activation_hooks") - if use_activation_hooks and linear_type == linear_type.DELAYED: - request.node.add_marker( - pytest.mark.xfail( - reason="use_activation_hooks is not supported for AOT", strict=True - ) - ) diff --git a/test/test_base.py b/test/test_base.py index b43d57cf..8a8233d4 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -60,9 +60,8 @@ def _test_linear_impl( m_ref, linear_type: LinearType, emulate: bool, - use_activation_hooks: bool = False, ): - m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks) + m_fp8 = get_float8_linear(linear_type, m_ref, emulate) for _ in range(2): if linear_requires_sync(linear_type): sync_float8_amax_and_scale_history(m_fp8) @@ -123,15 +122,12 @@ def _test_linear_impl( @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) - 
@pytest.mark.parametrize("use_activation_hooks", [True, False]) - @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed") @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear_nobias( self, x_shape, linear_type: LinearType, emulate: bool, - use_activation_hooks: bool, ): if not emulate: if not torch.cuda.is_available(): @@ -145,7 +141,7 @@ def test_linear_nobias( x = torch.randn(*x_shape, device="cuda") m_ref = nn.Linear(16, 32, bias=False, device="cuda") - self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks) + self._test_linear_impl(x, m_ref, linear_type, emulate) @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @@ -153,8 +149,6 @@ def test_linear_nobias( @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("use_activation_hooks", [True, False]) - @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed") @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear_bias( self, @@ -162,7 +156,6 @@ def test_linear_bias( linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype, - use_activation_hooks: bool, ): if not emulate: if not torch.cuda.is_available(): @@ -176,22 +169,19 @@ def test_linear_bias( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks) + self._test_linear_impl(x, m_ref, linear_type, emulate) @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("use_activation_hooks", [True, False]) - @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed") @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_autocast_outputs( self, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype, - use_activation_hooks: bool, ): if not emulate: if not torch.cuda.is_available(): @@ -204,7 +194,7 @@ def test_autocast_outputs( pytest.skip() m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks) + m = get_float8_linear(linear_type, m_ref, emulate) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) @@ -242,7 +232,7 @@ def test_type_cast( ) m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = get_float8_linear(linear_type, m, emulate, False) + m = get_float8_linear(linear_type, m, emulate) # Cast the module to dtype m = m.to(dtype=linear_dtype) diff --git a/test/test_compile.py b/test/test_compile.py index eadae864..2a9abba9 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -35,7 +35,6 @@ def _test_compile_base( emulate: bool, linear_type: LinearType, dtype: torch.dtype, - use_activation_hooks: bool, ): random.seed(0) torch.manual_seed(0) @@ -45,7 +44,7 @@ def _test_compile_base( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks) + m_fp8 = get_float8_linear(linear_type, m_ref, emulate) m_fp8 = torch.compile(m_fp8, backend=backend, 
diff --git a/test/test_compile.py b/test/test_compile.py
index eadae864..2a9abba9 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -35,7 +35,6 @@ def _test_compile_base(
     emulate: bool,
     linear_type: LinearType,
    dtype: torch.dtype,
-    use_activation_hooks: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
@@ -45,7 +44,7 @@
 
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
     m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-    m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
+    m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
 
     m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph)
     m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph)
@@ -64,52 +63,35 @@
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
 @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
-@pytest.mark.parametrize("use_activation_hooks", [False, True])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
     fullgraph,
     emulate: bool,
     linear_type: bool,
     dtype: torch.dtype,
-    use_activation_hooks: bool,
 ):
-    if linear_type == LinearType.DELAYED and use_activation_hooks:
-        pytest.skip("use_activation_hooks is only supported for dynamic linear")
     torch._dynamo.reset()
-    _test_compile_base(
-        "eager", fullgraph, emulate, linear_type, dtype, use_activation_hooks
-    )
+    _test_compile_base("eager", fullgraph, emulate, linear_type, dtype)
 
 
 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
-@pytest.mark.parametrize("use_activation_hooks", [False, True])
-# TODO this shouldn't fail but multiple fake modes
-@pytest.mark.usefixtures("x_fail_activation_hooks")
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
     fullgraph,
     emulate: bool,
     linear_type: bool,
     dtype: torch.dtype,
-    use_activation_hooks: bool,
 ):
-    if linear_type == LinearType.DELAYED and use_activation_hooks:
-        pytest.skip("use_activation_hooks is only supported for dynamic linear")
     torch._dynamo.reset()
-    _test_compile_base(
-        "aot_eager", fullgraph, emulate, linear_type, dtype, use_activation_hooks
-    )
+    _test_compile_base("aot_eager", fullgraph, emulate, linear_type, dtype)
 
 
 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize("emulate", [False])
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-@pytest.mark.parametrize("use_activation_hooks", [False, True])
-# TODO this shouldn't fail but multiple fake modes
-@pytest.mark.usefixtures("x_fail_activation_hooks")
 @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
@@ -117,14 +99,9 @@
     emulate: bool,
     linear_type: bool,
     dtype: torch.dtype,
-    use_activation_hooks: bool,
 ):
-    if linear_type == LinearType.DELAYED and use_activation_hooks:
-        pytest.skip("use_activation_hooks is only supported for dynamic linear")
     torch._dynamo.reset()
-    _test_compile_base(
-        "inductor", fullgraph, emulate, linear_type, dtype, use_activation_hooks
-    )
+    _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)
 
 
 class TestGraphBreaks(DynamoTestCase):
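
For reference, the compile tests above reduce to the following flow once the hook path is gone; this is a sketch rather than part of the patch, and it assumes a CUDA device as the tests do (aot_eager is the backend the tests exercise with emulate=True, while inductor is only run un-emulated on H100):

import torch
import torch.nn as nn

from float8_experimental.float8_linear_utils import LinearType, get_float8_linear

m_ref = nn.Linear(16, 32, bias=True, device="cuda")
m_fp8 = get_float8_linear(LinearType.DYNAMIC, m_ref, emulate=True)

# With the casts inlined in forward() instead of installed as module hooks,
# the dynamic linear traces cleanly with fullgraph=True.
m_fp8 = torch.compile(m_fp8, backend="aot_eager", fullgraph=True)
y = m_fp8(torch.randn(4, 16, device="cuda"))
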