From a104f02c42c612b2aa5b5c3f4016b5a10693aafa Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Sun, 3 Mar 2024 21:26:09 -0800 Subject: [PATCH 1/6] rename use_activation_hooks to cast_activation as titled, it turns out we don't need to install additional hooks base on our TP + FP8 design. The only thing we need to do here is to be able to turn off activation casting, so that we can put activation casting in the TP hooks So renaming the flag to cast_activation instead and delete relevant tests --- float8_experimental/config.py | 5 ---- float8_experimental/float8_dynamic_linear.py | 22 +++++++---------- float8_experimental/float8_linear.py | 6 ++--- float8_experimental/float8_linear_utils.py | 18 +++++++------- test/test_base.py | 20 ++++------------ test/test_compile.py | 25 ++++---------------- 6 files changed, 29 insertions(+), 67 deletions(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 9df065bc..f0ba914f 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -14,8 +14,3 @@ # this doesn't work with autocast + torch.compile + FSDP. Enabling this # option is useful for safety, but not strictly necessary. enable_pre_and_post_forward = True - -# If True, dynamic linear uses hooks for activation casting -# TODO(before land): add test coverage for both cases -# dynamic_use_activation_hooks = True -# dynamic_use_activation_hooks = False diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 14f9fbc7..66f55895 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -62,18 +62,18 @@ class Float8DynamicLinear(torch.nn.Linear): conversion to fp8 of the input and weight tensors. """ - def __init__(self, use_activation_hooks: bool, **super_kwargs): + def __init__(self, cast_activation: bool, **super_kwargs): """ Args: - use_activation_hooks (bool): whether to use activation hooks for casting to and from float8 + cast_activation (bool): whether to do activation casting to and from float8 """ super().__init__(**super_kwargs) - self.use_activation_hooks = use_activation_hooks + self.cast_activation = cast_activation def forward(self, x): # cast x to float8_e4m3fn if not using activation hooks - x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x) + x_fp8 = self.cast_to_float8_e4m3fn(x) if self.cast_activation else x # cast w to float8_e4m3fn w_fp8 = self.cast_to_float8_e4m3fn(self.weight) @@ -81,7 +81,7 @@ def forward(self, x): y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) # Cast gradY to float8_e5m2 during backward if not using activation hooks - if not self.use_activation_hooks: + if self.cast_activation: y = self.cast_to_float8_e5m2_bw(y) return y @@ -97,7 +97,7 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor: @classmethod def from_float( - cls, mod, emulate: bool = False, use_activation_hooks: bool = False + cls, mod, emulate: bool = False, cast_activation: bool = True ) -> "Float8DynamicLinear": """ Create an nn.Linear with fp8 compute from a regular nn.Linear @@ -105,7 +105,7 @@ def from_float( Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 - use_activation_hooks (bool): whether to use activation hooks for casting to and from float8 + cast_activation (bool): whether to do activation casting to and from float8 """ with torch.device("meta"): super_kwargs = { @@ -113,14 +113,8 @@ def from_float( "out_features": mod.out_features, "bias": False, } - new_mod = cls(use_activation_hooks, **super_kwargs) + new_mod = cls(cast_activation, **super_kwargs) new_mod.weight = mod.weight new_mod.bias = mod.bias new_mod.emulate = emulate - if new_mod.use_activation_hooks: - # install the hooks - new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook) - new_mod.register_forward_hook( - cast_grad_to_float8_e5m2_backward_forward_hook - ) return new_mod diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 64a55652..ede7e208 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -304,16 +304,16 @@ def forward(self, x): return y @classmethod - def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False): + def from_float(cls, mod, emulate: bool = False, cast_activation: bool = True): """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 - use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic + cast_activation (bool): whether to use activation hooks instead of inlining the casting logic """ - assert not use_activation_hooks, "use_activation_hooks is not supported yet!" + assert cast_activation is True, "cast activation option is not supported yet, we always cast activations!" # TODO Follow up! This is a great idea but we need the mixin base to create real # Tensors and the Linear base to create empty params # with torch.device("meta"): diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index ce3d4692..1529f983 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -33,14 +33,14 @@ def get_float8_linear( linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False, - use_activation_hooks: bool = False, + cast_activation: bool = True, ): """Returns a Float8Linear module of the given type, initialized from linear_ref. Args: linear_type: The type of Float8Linear to return. linear_ref: The linear module to initialize from. emulate: Whether to emulate the fp8 matmul logic in float32. - use_activation_hooks: Whether to use activation hooks for dynamic linear. + cast_activation: Whether to use activation hooks for dynamic linear. """ LINEAR_TYPE_MAP = { LinearType.DELAYED: Float8Linear, @@ -48,12 +48,12 @@ def get_float8_linear( } if linear_type not in LINEAR_TYPE_MAP: raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}") - if use_activation_hooks and linear_type != LinearType.DYNAMIC: - raise ValueError("use_activation_hooks is only supported for dynamic linear") + if not cast_activation and linear_type != LinearType.DYNAMIC: + raise ValueError("cast_activation option is only supported for dynamic linear") return LINEAR_TYPE_MAP[linear_type].from_float( copy.deepcopy(linear_ref), emulate=emulate, - use_activation_hooks=use_activation_hooks, + cast_activation=cast_activation, ) @@ -90,7 +90,7 @@ def swap_linear_with_float8_linear( *, skip_fqn_list: Optional[List[str]] = None, emulate: bool = False, - use_activation_hooks: bool = False, + cast_activation: bool = True, ) -> nn.Module: """ Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances @@ -102,7 +102,7 @@ def swap_linear_with_float8_linear( skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip. Linear submodules of these skipped modules will also be skipped. emulate (bool): Whether to emulate the fp8 matmul logic in fp32. - use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks. + cast_activation (bool): Whether to cast activations to fp8 using module hooks. """ module_names_to_skip = set(skip_fqn_list or []) if isinstance(module, nn.Linear): @@ -111,7 +111,7 @@ def swap_linear_with_float8_linear( f"Does not support a root nn.Linear with children: {module}" ) return module_cls.from_float( - module, emulate=emulate, use_activation_hooks=use_activation_hooks + module, emulate=emulate, cast_activation=cast_activation ) # Mark all modules to skip as visited @@ -135,7 +135,7 @@ def post_order_traversal( parent_module is not None ), f"Linear root module should return early: {module}" float8linear_module = module_cls.from_float( - module, emulate=emulate, use_activation_hooks=use_activation_hooks + module, emulate=emulate, cast_activation=cast_activation ) setattr(parent_module, module_name, float8linear_module) diff --git a/test/test_base.py b/test/test_base.py index a53deb91..eb68f1f4 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -58,9 +58,8 @@ def _test_linear_impl( m_ref, linear_type: LinearType, emulate: bool, - use_activation_hooks: bool = False, ): - m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks) + m_fp8 = get_float8_linear(linear_type, m_ref, emulate) for _ in range(2): if linear_requires_sync(linear_type): sync_float8_amax_and_scale_history(m_fp8) @@ -121,15 +120,12 @@ def _test_linear_impl( @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) - @pytest.mark.parametrize("use_activation_hooks", [True, False]) - @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed") @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear_nobias( self, x_shape, linear_type: LinearType, emulate: bool, - use_activation_hooks: bool, ): if not emulate: if not torch.cuda.is_available(): @@ -143,7 +139,7 @@ def test_linear_nobias( x = torch.randn(*x_shape, device="cuda") m_ref = nn.Linear(16, 32, bias=False, device="cuda") - self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks) + self._test_linear_impl(x, m_ref, linear_type, emulate) @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @@ -151,8 +147,6 @@ def test_linear_nobias( @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("use_activation_hooks", [True, False]) - @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed") @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear_bias( self, @@ -160,7 +154,6 @@ def test_linear_bias( linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype, - use_activation_hooks: bool, ): if not emulate: if not torch.cuda.is_available(): @@ -174,22 +167,19 @@ def test_linear_bias( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks) + self._test_linear_impl(x, m_ref, linear_type, emulate) @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("use_activation_hooks", [True, False]) - @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed") @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_autocast_outputs( self, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype, - use_activation_hooks: bool, ): if not emulate: if not torch.cuda.is_available(): @@ -202,7 +192,7 @@ def test_autocast_outputs( pytest.skip() m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks) + m = get_float8_linear(linear_type, m_ref, emulate) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) @@ -240,7 +230,7 @@ def test_type_cast( ) m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - m = get_float8_linear(linear_type, m, emulate, False) + m = get_float8_linear(linear_type, m, emulate) # Cast the module to dtype m = m.to(dtype=linear_dtype) diff --git a/test/test_compile.py b/test/test_compile.py index eadae864..2f7a83e3 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -35,7 +35,6 @@ def _test_compile_base( emulate: bool, linear_type: LinearType, dtype: torch.dtype, - use_activation_hooks: bool, ): random.seed(0) torch.manual_seed(0) @@ -45,7 +44,7 @@ def _test_compile_base( x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks) + m_fp8 = get_float8_linear(linear_type, m_ref, emulate) m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) @@ -64,20 +63,16 @@ def _test_compile_base( @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -@pytest.mark.parametrize("use_activation_hooks", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype, - use_activation_hooks: bool, ): - if linear_type == LinearType.DELAYED and use_activation_hooks: - pytest.skip("use_activation_hooks is only supported for dynamic linear") torch._dynamo.reset() _test_compile_base( - "eager", fullgraph, emulate, linear_type, dtype, use_activation_hooks + "eager", fullgraph, emulate, linear_type, dtype ) @@ -85,31 +80,22 @@ def test_eager_only( @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -@pytest.mark.parametrize("use_activation_hooks", [False, True]) -# TODO this shouldn't fail but multiple fake modes -@pytest.mark.usefixtures("x_fail_activation_hooks") @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype, - use_activation_hooks: bool, ): - if linear_type == LinearType.DELAYED and use_activation_hooks: - pytest.skip("use_activation_hooks is only supported for dynamic linear") torch._dynamo.reset() _test_compile_base( - "aot_eager", fullgraph, emulate, linear_type, dtype, use_activation_hooks + "aot_eager", fullgraph, emulate, linear_type, dtype ) @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]) -@pytest.mark.parametrize("use_activation_hooks", [False, True]) -# TODO this shouldn't fail but multiple fake modes -@pytest.mark.usefixtures("x_fail_activation_hooks") @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_inductor( @@ -117,13 +103,10 @@ def test_inductor( emulate: bool, linear_type: bool, dtype: torch.dtype, - use_activation_hooks: bool, ): - if linear_type == LinearType.DELAYED and use_activation_hooks: - pytest.skip("use_activation_hooks is only supported for dynamic linear") torch._dynamo.reset() _test_compile_base( - "inductor", fullgraph, emulate, linear_type, dtype, use_activation_hooks + "inductor", fullgraph, emulate, linear_type, dtype ) From be02472f6839541efbcb7e3dbac6dd9b96a036ff Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Sun, 3 Mar 2024 22:06:27 -0800 Subject: [PATCH 2/6] format --- float8_experimental/float8_linear.py | 4 +++- test/test_compile.py | 12 +++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index ede7e208..cc65cfc3 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -313,7 +313,9 @@ def from_float(cls, mod, emulate: bool = False, cast_activation: bool = True): emulate (bool): whether to emulate fp8 matmul logic in float32 cast_activation (bool): whether to use activation hooks instead of inlining the casting logic """ - assert cast_activation is True, "cast activation option is not supported yet, we always cast activations!" + assert ( + cast_activation is True + ), "cast activation option is not supported yet, we always cast activations!" # TODO Follow up! This is a great idea but we need the mixin base to create real # Tensors and the Linear base to create empty params # with torch.device("meta"): diff --git a/test/test_compile.py b/test/test_compile.py index 2f7a83e3..2a9abba9 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -71,9 +71,7 @@ def test_eager_only( dtype: torch.dtype, ): torch._dynamo.reset() - _test_compile_base( - "eager", fullgraph, emulate, linear_type, dtype - ) + _test_compile_base("eager", fullgraph, emulate, linear_type, dtype) @pytest.mark.parametrize("fullgraph", [True]) @@ -88,9 +86,7 @@ def test_aot_eager( dtype: torch.dtype, ): torch._dynamo.reset() - _test_compile_base( - "aot_eager", fullgraph, emulate, linear_type, dtype - ) + _test_compile_base("aot_eager", fullgraph, emulate, linear_type, dtype) @pytest.mark.parametrize("fullgraph", [True]) @@ -105,9 +101,7 @@ def test_inductor( dtype: torch.dtype, ): torch._dynamo.reset() - _test_compile_base( - "inductor", fullgraph, emulate, linear_type, dtype - ) + _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype) class TestGraphBreaks(DynamoTestCase): From 0136d1f34d31068d35d55b7fa32207d5cc22c66c Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 11 Mar 2024 21:52:24 -0700 Subject: [PATCH 3/6] switch to have tensor_casted_to_fp8 util function instead --- float8_experimental/float8_dynamic_linear.py | 51 ++++++-------------- float8_experimental/float8_linear.py | 5 +- float8_experimental/float8_linear_utils.py | 15 +----- float8_experimental/float8_tensor.py | 13 +++++ 4 files changed, 32 insertions(+), 52 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 66f55895..4a5ac2e1 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -8,7 +8,11 @@ """ import torch -from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd +from float8_experimental.float8_tensor import ( + Float8Tensor, + tensor_already_casted_to_fp8, + to_fp8_no_autograd, +) from float8_experimental.float8_utils import tensor_to_scale @@ -30,6 +34,9 @@ def forward( @staticmethod def backward(ctx, gradY): + if tensor_already_casted_to_fp8(gradY): + # check to early return if already casted to float8 + return gradY, None gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2) fp8_tensor = to_fp8_no_autograd( gradY, gradY_scale, torch.float8_e5m2, ctx.emulate @@ -37,43 +44,18 @@ def backward(ctx, gradY): return fp8_tensor, None -def cast_x_to_float8_e4m3fn_pre_hook(module, args): - """ - Hook to cast the incoming activation to `torch.float8_e4m3fn` - """ - return module.cast_to_float8_e4m3fn(args[0]) - - -def cast_grad_to_float8_e5m2_backward_forward_hook(module, input, output): - """This is a forward hook that sends the output of the model through - a no-op in the forward but a cast to float8_e5m2 in the backward. - - Args: - module (nn.Module): the module to cast the output of - input (Tensor): the input to the module forward call - output (Tensor): the output of the module forward - """ - return module.cast_to_float8_e5m2_bw(output) - - class Float8DynamicLinear(torch.nn.Linear): """ A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly conversion to fp8 of the input and weight tensors. """ - def __init__(self, cast_activation: bool, **super_kwargs): - """ - Args: - cast_activation (bool): whether to do activation casting to and from float8 - """ + def __init__(self, **super_kwargs): super().__init__(**super_kwargs) - self.cast_activation = cast_activation - def forward(self, x): # cast x to float8_e4m3fn if not using activation hooks - x_fp8 = self.cast_to_float8_e4m3fn(x) if self.cast_activation else x + x_fp8 = self.cast_to_float8_e4m3fn(x) # cast w to float8_e4m3fn w_fp8 = self.cast_to_float8_e4m3fn(self.weight) @@ -81,12 +63,14 @@ def forward(self, x): y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) # Cast gradY to float8_e5m2 during backward if not using activation hooks - if self.cast_activation: - y = self.cast_to_float8_e5m2_bw(y) + y = self.cast_to_float8_e5m2_bw(y) return y def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor: + if tensor_already_casted_to_fp8(inpt_tensor): + # check to early return if already casted to float8 + return inpt_tensor scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn) return Float8Tensor.to_float8( inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate @@ -96,16 +80,13 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor: return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate) @classmethod - def from_float( - cls, mod, emulate: bool = False, cast_activation: bool = True - ) -> "Float8DynamicLinear": + def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear": """ Create an nn.Linear with fp8 compute from a regular nn.Linear Args: mod (torch.nn.Linear): nn.Linear to convert emulate (bool): whether to emulate fp8 matmul logic in float32 - cast_activation (bool): whether to do activation casting to and from float8 """ with torch.device("meta"): super_kwargs = { @@ -113,7 +94,7 @@ def from_float( "out_features": mod.out_features, "bias": False, } - new_mod = cls(cast_activation, **super_kwargs) + new_mod = cls(**super_kwargs) new_mod.weight = mod.weight new_mod.bias = mod.bias new_mod.emulate = emulate diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index cc65cfc3..4f6d277d 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -304,7 +304,7 @@ def forward(self, x): return y @classmethod - def from_float(cls, mod, emulate: bool = False, cast_activation: bool = True): + def from_float(cls, mod, emulate: bool = False): """ Create an nn.Linear with fp8 compute from a regular nn.Linear @@ -313,9 +313,6 @@ def from_float(cls, mod, emulate: bool = False, cast_activation: bool = True): emulate (bool): whether to emulate fp8 matmul logic in float32 cast_activation (bool): whether to use activation hooks instead of inlining the casting logic """ - assert ( - cast_activation is True - ), "cast activation option is not supported yet, we always cast activations!" # TODO Follow up! This is a great idea but we need the mixin base to create real # Tensors and the Linear base to create empty params # with torch.device("meta"): diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 1529f983..f7cdd1d6 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -33,14 +33,12 @@ def get_float8_linear( linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False, - cast_activation: bool = True, ): """Returns a Float8Linear module of the given type, initialized from linear_ref. Args: linear_type: The type of Float8Linear to return. linear_ref: The linear module to initialize from. emulate: Whether to emulate the fp8 matmul logic in float32. - cast_activation: Whether to use activation hooks for dynamic linear. """ LINEAR_TYPE_MAP = { LinearType.DELAYED: Float8Linear, @@ -48,12 +46,9 @@ def get_float8_linear( } if linear_type not in LINEAR_TYPE_MAP: raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}") - if not cast_activation and linear_type != LinearType.DYNAMIC: - raise ValueError("cast_activation option is only supported for dynamic linear") return LINEAR_TYPE_MAP[linear_type].from_float( copy.deepcopy(linear_ref), emulate=emulate, - cast_activation=cast_activation, ) @@ -90,7 +85,6 @@ def swap_linear_with_float8_linear( *, skip_fqn_list: Optional[List[str]] = None, emulate: bool = False, - cast_activation: bool = True, ) -> nn.Module: """ Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances @@ -102,7 +96,6 @@ def swap_linear_with_float8_linear( skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip. Linear submodules of these skipped modules will also be skipped. emulate (bool): Whether to emulate the fp8 matmul logic in fp32. - cast_activation (bool): Whether to cast activations to fp8 using module hooks. """ module_names_to_skip = set(skip_fqn_list or []) if isinstance(module, nn.Linear): @@ -110,9 +103,7 @@ def swap_linear_with_float8_linear( raise AssertionError( f"Does not support a root nn.Linear with children: {module}" ) - return module_cls.from_float( - module, emulate=emulate, cast_activation=cast_activation - ) + return module_cls.from_float(module, emulate=emulate) # Mark all modules to skip as visited root_module = module @@ -134,9 +125,7 @@ def post_order_traversal( assert ( parent_module is not None ), f"Linear root module should return early: {module}" - float8linear_module = module_cls.from_float( - module, emulate=emulate, cast_activation=cast_activation - ) + float8linear_module = module_cls.from_float(module, emulate=emulate) setattr(parent_module, module_name, float8linear_module) post_order_traversal(root_module, "", None) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 3647e185..f5873ea5 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -14,6 +14,19 @@ aten = torch.ops.aten +def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: + """ + Check if the tensor is already casted to fp8 + """ + if isinstance(tensor, Float8Tensor): + return True + elif isinstance(tensor, DTensor) and isinstance(tensor._local_tensor, Float8Tensor): + # TODO: shall we stick to public API and directly use tensor.to_local() here? + return True + + return False + + def to_fp8_no_autograd( x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool ) -> "Float8Tensor": From eea4595174483ff04ea782d758889ed9390126da Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 12 Mar 2024 10:48:34 -0700 Subject: [PATCH 4/6] remove conftest --- test/conftest.py | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 test/conftest.py diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index 30e42fd4..00000000 --- a/test/conftest.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - - -@pytest.fixture -def x_fail_activation_hooks(request): - use_activation_hooks = request.getfixturevalue("use_activation_hooks") - if use_activation_hooks: - request.node.add_marker( - pytest.mark.xfail( - reason="use_activation_hooks is not supported for AOT", strict=True - ) - ) - - -@pytest.fixture -def x_fail_activation_hooks_with_delayed(request): - linear_type = request.getfixturevalue("linear_type") - use_activation_hooks = request.getfixturevalue("use_activation_hooks") - if use_activation_hooks and linear_type == linear_type.DELAYED: - request.node.add_marker( - pytest.mark.xfail( - reason="use_activation_hooks is not supported for AOT", strict=True - ) - ) From 020642a8a324da713d15872a611e59778c1d4424 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 13 Mar 2024 10:17:51 -0700 Subject: [PATCH 5/6] recursive check already casted --- float8_experimental/float8_tensor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index f5873ea5..830d9bb2 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -9,6 +9,7 @@ from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated +import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import DTensor aten = torch.ops.aten @@ -20,9 +21,11 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: """ if isinstance(tensor, Float8Tensor): return True - elif isinstance(tensor, DTensor) and isinstance(tensor._local_tensor, Float8Tensor): + elif isinstance(tensor, DTensor): # TODO: shall we stick to public API and directly use tensor.to_local() here? - return True + return tensor_already_casted_to_fp8(tensor._local_tensor) + elif isinstance(tensor, funcol.AsyncCollectiveTensor): + return tensor_already_casted_to_fp8(tensor.elem) return False From a1cb6ef6d85b5ba9822d91e1c8fb9aa703343f06 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Thu, 14 Mar 2024 13:15:46 -0700 Subject: [PATCH 6/6] lint --- float8_experimental/float8_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 830d9bb2..365e6093 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -7,9 +7,9 @@ import torch -from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated - import torch.distributed._functional_collectives as funcol + +from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated from torch.distributed._tensor import DTensor aten = torch.ops.aten