This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit bfc60fb

wanchaol authored and facebook-github-bot committed
early return by check tensor already casted or not (#233)
Summary: As titled. It turns out we don't need to install additional hooks, based on our TP + FP8 design. The only thing we need here is the ability to turn off activation casting in the module, so that the casting can happen in the TP hooks instead. So we just check whether the tensor has already been cast to fp8; in TP we cast the activation into the DTensor's Float8Colwise/Rowwise instead.

Pull Request resolved: #233
Reviewed By: drisspg
Differential Revision: D54497744
Pulled By: wanchaol
fbshipit-source-id: 20c4f6799bf91716778e2257388cb53b13373064
1 parent f094e75 commit bfc60fb

8 files changed, 46 additions and 129 deletions

float8_experimental/config.py

Lines changed: 0 additions & 5 deletions

@@ -14,8 +14,3 @@
 # this doesn't work with autocast + torch.compile + FSDP. Enabling this
 # option is useful for safety, but not strictly necessary.
 enable_pre_and_post_forward = True
-
-# If True, dynamic linear uses hooks for activation casting
-# TODO(before land): add test coverage for both cases
-# dynamic_use_activation_hooks = True
-# dynamic_use_activation_hooks = False
float8_experimental/float8_dynamic_linear.py

Lines changed: 16 additions & 41 deletions

@@ -8,7 +8,11 @@
 """
 import torch
 
-from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    tensor_already_casted_to_fp8,
+    to_fp8_no_autograd,
+)
 from float8_experimental.float8_utils import tensor_to_scale
 
 
@@ -30,63 +34,43 @@ def forward(
 
     @staticmethod
     def backward(ctx, gradY):
+        if tensor_already_casted_to_fp8(gradY):
+            # check to early return if already casted to float8
+            return gradY, None
         gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
         fp8_tensor = to_fp8_no_autograd(
             gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
         )
         return fp8_tensor, None
 
 
-def cast_x_to_float8_e4m3fn_pre_hook(module, args):
-    """
-    Hook to cast the incoming activation to `torch.float8_e4m3fn`
-    """
-    return module.cast_to_float8_e4m3fn(args[0])
-
-
-def cast_grad_to_float8_e5m2_backward_forward_hook(module, input, output):
-    """This is a forward hook that sends the output of the model through
-    a no-op in the forward but a cast to float8_e5m2 in the backward.
-
-    Args:
-        module (nn.Module): the module to cast the output of
-        input (Tensor): the input to the module forward call
-        output (Tensor): the output of the module forward
-    """
-    return module.cast_to_float8_e5m2_bw(output)
-
-
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
     conversion to fp8 of the input and weight tensors.
     """
 
-    def __init__(self, use_activation_hooks: bool, **super_kwargs):
-        """
-        Args:
-            use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
-        """
+    def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
-        self.use_activation_hooks = use_activation_hooks
-
     def forward(self, x):
         # cast x to float8_e4m3fn if not using activation hooks
-        x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)
+        x_fp8 = self.cast_to_float8_e4m3fn(x)
 
         # cast w to float8_e4m3fn
         w_fp8 = self.cast_to_float8_e4m3fn(self.weight)
 
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward if not using activation hooks
-        if not self.use_activation_hooks:
-            y = self.cast_to_float8_e5m2_bw(y)
+        y = self.cast_to_float8_e5m2_bw(y)
 
         return y
 
     def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
+        if tensor_already_casted_to_fp8(inpt_tensor):
+            # check to early return if already casted to float8
+            return inpt_tensor
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -96,31 +80,22 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
         return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)
 
     @classmethod
-    def from_float(
-        cls, mod, emulate: bool = False, use_activation_hooks: bool = False
-    ) -> "Float8DynamicLinear":
+    def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
-            use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
         """
         with torch.device("meta"):
            super_kwargs = {
                "in_features": mod.in_features,
                "out_features": mod.out_features,
                "bias": False,
            }
-            new_mod = cls(use_activation_hooks, **super_kwargs)
+            new_mod = cls(**super_kwargs)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
-        if new_mod.use_activation_hooks:
-            # install the hooks
-            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
-            new_mod.register_forward_hook(
-                cast_grad_to_float8_e5m2_backward_forward_hook
-            )
         return new_mod
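To make the hand-off described in the summary concrete, here is an illustrative, untested sketch (not code from this commit): a TP-style forward pre-hook pre-casts the activation to Float8Tensor, and cast_to_float8_e4m3fn then passes it through unchanged because of the new early return. The hook below and the emulate=True setting are assumptions for illustration only; the real TP path casts into DTensor-wrapped float8 as the summary notes.

import torch

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale


def tp_style_precast_hook(module, args):
    # Stand-in for the casting a TP hook would perform before the linear runs.
    x = args[0]
    scale = tensor_to_scale(x, torch.float8_e4m3fn)
    return (Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn, emulate=True),)


m = Float8DynamicLinear.from_float(torch.nn.Linear(16, 32, bias=False), emulate=True)
m.register_forward_pre_hook(tp_style_precast_hook)

x = torch.randn(4, 16)
x_fp8 = Float8Tensor.to_float8(
    x, tensor_to_scale(x, torch.float8_e4m3fn), torch.float8_e4m3fn, emulate=True
)

# An already-cast input is handed back unchanged by the early return ...
assert m.cast_to_float8_e4m3fn(x_fp8) is x_fp8
# ... while a plain tensor still goes through the usual dynamic cast.
assert isinstance(m.cast_to_float8_e4m3fn(x), Float8Tensor)

This mirrors the pre-hook deleted from this file, except the cast now happens outside the module and the module merely checks whether any work is left to do.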

float8_experimental/float8_linear.py

Lines changed: 2 additions & 3 deletions

@@ -304,16 +304,15 @@ def forward(self, x):
         return y
 
     @classmethod
-    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
+    def from_float(cls, mod, emulate: bool = False):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
-            use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic
+            cast_activation (bool): whether to use activation hooks instead of inlining the casting logic
         """
-        assert not use_activation_hooks, "use_activation_hooks is not supported yet!"
         # TODO Follow up! This is a great idea but we need the mixin base to create real
         # Tensors and the Linear base to create empty params
         # with torch.device("meta"):

float8_experimental/float8_linear_utils.py

Lines changed: 2 additions & 13 deletions

@@ -33,27 +33,22 @@ def get_float8_linear(
     linear_type: LinearType,
     linear_ref: torch.nn.Linear,
     emulate: bool = False,
-    use_activation_hooks: bool = False,
 ):
     """Returns a Float8Linear module of the given type, initialized from linear_ref.
     Args:
         linear_type: The type of Float8Linear to return.
         linear_ref: The linear module to initialize from.
         emulate: Whether to emulate the fp8 matmul logic in float32.
-        use_activation_hooks: Whether to use activation hooks for dynamic linear.
     """
     LINEAR_TYPE_MAP = {
         LinearType.DELAYED: Float8Linear,
         LinearType.DYNAMIC: Float8DynamicLinear,
     }
     if linear_type not in LINEAR_TYPE_MAP:
         raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}")
-    if use_activation_hooks and linear_type != LinearType.DYNAMIC:
-        raise ValueError("use_activation_hooks is only supported for dynamic linear")
     return LINEAR_TYPE_MAP[linear_type].from_float(
         copy.deepcopy(linear_ref),
         emulate=emulate,
-        use_activation_hooks=use_activation_hooks,
     )
 
 
@@ -104,7 +99,6 @@ def swap_linear_with_float8_linear(
     *,
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
-    use_activation_hooks: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
 ) -> nn.Module:
     """
@@ -117,7 +111,6 @@
         skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
             Linear submodules of these skipped modules will also be skipped.
         emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
-        use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
         linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
             that pass the filter function will be swapped.
     """
@@ -129,9 +122,7 @@
            raise AssertionError(
                f"Does not support a root nn.Linear with children: {module}"
            )
-        return module_cls.from_float(
-            module, emulate=emulate, use_activation_hooks=use_activation_hooks
-        )
+        return module_cls.from_float(module, emulate=emulate)
 
     # Mark all modules to skip as visited
     root_module = module
@@ -155,9 +146,7 @@ def post_order_traversal(
            assert (
                parent_module is not None
            ), f"Linear root module should return early: {module}"
-            float8linear_module = module_cls.from_float(
-                module, emulate=emulate, use_activation_hooks=use_activation_hooks
-            )
+            float8linear_module = module_cls.from_float(module, emulate=emulate)
            setattr(parent_module, module_name, float8linear_module)
 
     post_order_traversal(root_module, "", None)
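A hedged usage sketch of the simplified call signature after this change (not from the commit itself): get_float8_linear no longer accepts use_activation_hooks, so callers pass only the linear type, the reference module, and emulate. Import paths are assumptions based on the diff context.

import torch

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import LinearType, get_float8_linear

lin = torch.nn.Linear(16, 32, bias=False)
# emulate=True is assumed here so the sketch does not require fp8-capable hardware.
fp8_lin = get_float8_linear(LinearType.DYNAMIC, lin, emulate=True)
assert isinstance(fp8_lin, Float8DynamicLinear)

swap_linear_with_float8_linear drops the flag in the same way, as the hunks above show.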

float8_experimental/float8_tensor.py

Lines changed: 17 additions & 1 deletion

@@ -7,13 +7,29 @@
 
 import torch
 
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+import torch.distributed._functional_collectives as funcol
 
+from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
 from torch.distributed._tensor import DTensor
 
 aten = torch.ops.aten
 
 
+def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
+    """
+    Check if the tensor is already casted to fp8
+    """
+    if isinstance(tensor, Float8Tensor):
+        return True
+    elif isinstance(tensor, DTensor):
+        # TODO: shall we stick to public API and directly use tensor.to_local() here?
+        return tensor_already_casted_to_fp8(tensor._local_tensor)
+    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
+        return tensor_already_casted_to_fp8(tensor.elem)
+
+    return False
+
+
 def to_fp8_no_autograd(
     x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
 ) -> "Float8Tensor":
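A hedged sketch of the new helper's behavior (not part of the commit): a plain tensor reports False and a Float8Tensor reports True; per the code above, a DTensor or AsyncCollectiveTensor is unwrapped and checked recursively. emulate=True is assumed so the example does not need fp8 hardware.

import torch

from float8_experimental.float8_tensor import (
    Float8Tensor,
    tensor_already_casted_to_fp8,
)
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(8, 8)
assert not tensor_already_casted_to_fp8(x)  # plain tensor -> False

scale = tensor_to_scale(x, torch.float8_e4m3fn)
x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn, emulate=True)
assert tensor_already_casted_to_fp8(x_fp8)  # Float8Tensor -> True

This is the check that Float8DynamicLinear and NoopFwToFloat8E5M2Bw now call before casting, which is what lets the TP hooks own the cast.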

test/conftest.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

test/test_base.py

Lines changed: 5 additions & 15 deletions

@@ -60,9 +60,8 @@ def _test_linear_impl(
         m_ref,
         linear_type: LinearType,
         emulate: bool,
-        use_activation_hooks: bool = False,
     ):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
+        m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
         for _ in range(2):
             if linear_requires_sync(linear_type):
                 sync_float8_amax_and_scale_history(m_fp8)
@@ -123,15 +122,12 @@ def _test_linear_impl(
     @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize("use_activation_hooks", [True, False])
-    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_nobias(
         self,
         x_shape,
         linear_type: LinearType,
         emulate: bool,
-        use_activation_hooks: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -145,24 +141,21 @@ def test_linear_nobias(
 
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)
+        self._test_linear_impl(x, m_ref, linear_type, emulate)
 
     @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("use_activation_hooks", [True, False])
-    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_bias(
         self,
         x_shape,
         linear_type: LinearType,
         emulate: bool,
         linear_dtype: torch.dtype,
-        use_activation_hooks: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -176,22 +169,19 @@ def test_linear_bias(
 
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)
+        self._test_linear_impl(x, m_ref, linear_type, emulate)
 
     @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    @pytest.mark.parametrize("use_activation_hooks", [True, False])
-    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_autocast_outputs(
         self,
         linear_type: LinearType,
         emulate: bool,
         linear_dtype: torch.dtype,
-        use_activation_hooks: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -204,7 +194,7 @@ def test_autocast_outputs(
             pytest.skip()
 
         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
+        m = get_float8_linear(linear_type, m_ref, emulate)
 
         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
@@ -242,7 +232,7 @@ def test_type_cast(
         )
 
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = get_float8_linear(linear_type, m, emulate, False)
+        m = get_float8_linear(linear_type, m, emulate)
 
         # Cast the module to dtype
         m = m.to(dtype=linear_dtype)
