This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f64d339

Authored by vkuzo and committed by facebook-github-bot
move tensor scaling configuration to Float8LinearConfig (#325)
Summary:
Pull Request resolved: #325

Moves the fields that configure per-tensor scaling options to `Float8LinearConfig`, to set us up for upcoming new configurations such as other scaling granularities, and to simplify configuration in general.

Reviewed By: weifengpy

Differential Revision: D60176980

fbshipit-source-id: 0950241b1f5d8ffe95dd67124e93fd4aaed50cf2
1 parent 27a3277 commit f64d339

19 files changed: 278 additions and 215 deletions (only a subset of the changed files is shown below).
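
The summary above describes one mechanical change: per-tensor scaling settings move from standalone `scaling_type_*` keyword arguments into the `Float8LinearConfig` object. A minimal before/after sketch of the conversion call, using only the public names that appear in the diffs below (`Float8LinearConfig`, `Float8TensorCastConfig`, `TensorScalingType`, `swap_linear_with_float8_linear`); the toy model is hypothetical:

import torch.nn as nn

from float8_experimental import (
    Float8LinearConfig,
    Float8TensorCastConfig,
    TensorScalingType,
    swap_linear_with_float8_linear,
)

# hypothetical toy model
m = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))

# before this commit (sketch): scaling types were passed as separate kwargs, e.g.
#   swap_linear_with_float8_linear(m, scaling_type_weight=TensorScalingType.DELAYED, ...)

# after this commit: per-tensor scaling lives on the config object
config = Float8LinearConfig(
    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
)
swap_linear_with_float8_linear(m, config=config)  # converts the nn.Linear modules in place

Only the tensors whose scaling type differs from the dynamic default need an explicit `Float8TensorCastConfig`.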

README.md

Lines changed: 4 additions & 4 deletions
@@ -95,19 +95,19 @@ m = Model(...)
 # gated with config.enable_amax_init and
 # config.enable_pre_and_post_forward are needed for
 # autocast + compile + FSDP + float8 to work
-from float8_experimental import Float8LinearConfig
+from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig
 config = Float8LinearConfig(
     enable_amax_init=False,  # only needed for autocast + compile + FSDP + float8 delayed
     enable_pre_and_post_forward=False,  # only needed for autocast + compile + FSDP + float8 delayed
+    cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
+    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
+    cast_config_grad_output=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
 )

 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
 swap_linear_with_float8_linear(
     m,
-    scaling_type_input=TensorScalingType.DELAYED,
-    scaling_type_weight=TensorScalingType.DELAYED,
-    scaling_type_grad_output=TensorScalingType.DELAYED,
     config=config,
 )


benchmarks/bench_linear_float8.py

Lines changed: 15 additions & 7 deletions
@@ -14,7 +14,12 @@

 import torch
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.config import (
+    Float8LinearConfig,
+    Float8TensorCastConfig,
+    TensorScalingType,
+)
+from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
@@ -105,6 +110,13 @@ def main(
     scaling_type_input = TensorScalingType(scaling_type_input)
     scaling_type_weight = TensorScalingType(scaling_type_weight)
     scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
+    config = Float8LinearConfig(
+        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=Float8TensorCastConfig(
+            scaling_type=scaling_type_grad_output
+        ),
+    )

     # LLaMa 2 70B single-node weight shapes
     # assumes fused attn.wqkv and ffn.w13
@@ -136,9 +148,7 @@ def main(
     linear_float8 = Float8Linear.from_float(
         copy.deepcopy(linear_ref),
         emulate=False,
-        scaling_type_input=scaling_type_input,
-        scaling_type_weight=scaling_type_weight,
-        scaling_type_grad_output=scaling_type_grad_output,
+        config=config,
     )
     scaling_repr = linear_float8.scaling_repr()

@@ -153,9 +163,7 @@ def main(
     ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

     def float8_forw_backward():
-        if linear_requires_sync(
-            scaling_type_input, scaling_type_weight, scaling_type_grad_output
-        ):
+        if linear_requires_sync(config):
             sync_float8_amax_and_scale_history(linear_float8)
         linear_float8(input_tensor).sum().backward()

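
The hunk above also changes `linear_requires_sync` to take the whole config object instead of three scaling types. A minimal sketch of the resulting delayed-scaling step, assuming `model` has already been converted with `swap_linear_with_float8_linear(model, config=config)`; `data`, `loss_fn`, and `optimizer` are hypothetical placeholders:

from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)

def float8_train_step(model, config, data, loss_fn, optimizer):
    # amax/scale syncing is only needed when at least one tensor uses delayed
    # scaling; linear_requires_sync now derives that from the config object
    if linear_requires_sync(config):
        sync_float8_amax_and_scale_history(model)
    loss = loss_fn(model(data))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()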

benchmarks/bench_multi_gpu.py

Lines changed: 14 additions & 4 deletions
@@ -14,7 +14,11 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import TensorScalingType
+from float8_experimental.config import (
+    Float8LinearConfig,
+    Float8TensorCastConfig,
+    TensorScalingType,
+)
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
@@ -28,6 +32,14 @@
 B, M, K, N = 32, 1024, 1024, 1024
 lr = 0.01

+config = Float8LinearConfig(
+    cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
+    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
+    cast_config_grad_output=Float8TensorCastConfig(
+        scaling_type=TensorScalingType.DELAYED
+    ),
+)
+

 def benchmark_torch_function_in_microseconds(
     func: Callable,
@@ -68,9 +80,7 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
     swap_linear_with_float8_linear(
         m,
         emulate=False,
-        scaling_type_input=TensorScalingType.DELAYED,
-        scaling_type_weight=TensorScalingType.DELAYED,
-        scaling_type_grad_output=TensorScalingType.DELAYED,
+        config=config,
     )
     return m


benchmarks/profile_linear_float8.py

Lines changed: 14 additions & 11 deletions
@@ -18,7 +18,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import TensorScalingType
+from float8_experimental.config import (
+    Float8LinearConfig,
+    Float8TensorCastConfig,
+    TensorScalingType,
+)
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
@@ -216,6 +220,13 @@ def main(
     scaling_type_input = TensorScalingType(scaling_type_input)
     scaling_type_weight = TensorScalingType(scaling_type_weight)
     scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
+    config = Float8LinearConfig(
+        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=Float8TensorCastConfig(
+            scaling_type=scaling_type_grad_output
+        ),
+    )
     scaling_repr = "_".join(
         [
             s.short_str()
@@ -256,14 +267,8 @@ def main(

     m_ref = m_ref.to(device).to(ref_dtype)

-    extra_kwargs = {
-        "scaling_type_input": scaling_type_input,
-        "scaling_type_weight": scaling_type_weight,
-        "scaling_type_grad_output": scaling_type_grad_output,
-    }
-
     m_float8 = copy.deepcopy(m_ref)
-    swap_linear_with_float8_linear(m_float8, **extra_kwargs)
+    swap_linear_with_float8_linear(m_float8, config=config)

     def ref_forw_backward(x):
         out = m_ref(x)
@@ -281,9 +286,7 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(
-            scaling_type_input, scaling_type_weight, scaling_type_grad_output
-        ):
+        if linear_requires_sync(config):
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)

float8_experimental/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -4,7 +4,11 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
-from float8_experimental.config import Float8LinearConfig
+from float8_experimental.config import (
+    Float8LinearConfig,
+    Float8TensorCastConfig,
+    TensorScalingType,
+)
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import (
@@ -21,7 +25,9 @@

 __all__ = [
     # configuration
+    "TensorScalingType",
     "Float8LinearConfig",
+    "Float8TensorCastConfig",
     # top level UX
     "swap_linear_with_float8_linear",
     # TODO(future): remove Float8Tensor and Float8Linear from public API

float8_experimental/config.py

Lines changed: 33 additions & 0 deletions
@@ -4,16 +4,49 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import enum
 from dataclasses import dataclass


+class TensorScalingType(enum.Enum):
+    DELAYED = "delayed"
+    DYNAMIC = "dynamic"
+
+    def short_str(self):
+        if self is TensorScalingType.DELAYED:
+            return "del"
+        else:
+            assert self is TensorScalingType.DYNAMIC
+            return "dyn"
+
+
+@dataclass(frozen=True)
+class Float8TensorCastConfig:
+    """
+    Configuration for casting a single tensor to float8
+    """
+
+    scaling_type: TensorScalingType = TensorScalingType.DYNAMIC
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
     Configuration for converting a `torch.nn.Linear` module to float8
     for training.
     """

+    #
+    # Per-tensor configuration for `input`, `weight`, `grad_output`
+    #
+    cast_config_input: Float8TensorCastConfig = Float8TensorCastConfig()
+    cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
+    cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()
+
+    #
+    # Per-linear configuration
+    #
+
     # If True, on the first iteration of Float8Linear the amaxes will be
     # initialized with the incoming data. As of 2023-12-30, this doesn't work
     # with autocast + torch.compile + FSDP. Enabling this option is nice for
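
Based on the dataclasses added above, a default `Float8LinearConfig()` keeps dynamic scaling for all three tensors, and each tensor can be overridden independently. A small sketch of how the pieces compose (the `print` is only illustrative):

from float8_experimental.config import (
    Float8LinearConfig,
    Float8TensorCastConfig,
    TensorScalingType,
)

# defaults: every tensor uses dynamic scaling
default_config = Float8LinearConfig()
assert default_config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC

# override a single tensor and leave the others at their defaults
delayed_weight = Float8LinearConfig(
    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
)
print(delayed_weight.cast_config_weight.scaling_type.short_str())  # prints "del"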

float8_experimental/float8_linear.py

Lines changed: 8 additions & 34 deletions
@@ -14,7 +14,7 @@

 import torch

-from float8_experimental.config import Float8LinearConfig
+from float8_experimental.config import Float8LinearConfig, TensorScalingType

 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
@@ -148,18 +148,6 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


-class TensorScalingType(enum.Enum):
-    DELAYED = "delayed"
-    DYNAMIC = "dynamic"
-
-    def short_str(self):
-        if self is TensorScalingType.DELAYED:
-            return "del"
-        else:
-            assert self is TensorScalingType.DYNAMIC
-            return "dyn"
-
-
 class Float8Linear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
@@ -170,9 +158,6 @@ def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
         * `delayed_scaling_recipe`: configuration for delayed scaling
-        * `scaling_type_input`: delayed vs dynamic scaling for `input`
-        * `scaling_type_weight`: delayed vs dynamic scaling for `weight`
-        * `scaling_type_grad_output`: delayed vs dynamic scaling for `grad_output`
         * `config`: Float8LinearConfig
         """

@@ -182,20 +167,13 @@ def __init__(self, *args, **kwargs):
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         emulate = kwargs.pop("emulate", False)
-        scaling_type_input = kwargs.pop("scaling_type_input", TensorScalingType.DYNAMIC)
-        scaling_type_weight = kwargs.pop(
-            "scaling_type_weight", TensorScalingType.DYNAMIC
-        )
-        scaling_type_grad_output = kwargs.pop(
-            "scaling_type_grad_output", TensorScalingType.DYNAMIC
-        )
         config = kwargs.pop("config")
         super().__init__(*args, **kwargs)

         # Defines the scaling behavior of input, weight, grad_output
-        self.scaling_type_input = scaling_type_input
-        self.scaling_type_weight = scaling_type_weight
-        self.scaling_type_grad_output = scaling_type_grad_output
+        self.scaling_type_input = config.cast_config_input.scaling_type
+        self.scaling_type_weight = config.cast_config_weight.scaling_type
+        self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
         # Convenience flag to skip code related to delayed scaling
         self.has_any_delayed_scaling = (
             self.scaling_type_input is TensorScalingType.DELAYED
@@ -457,9 +435,6 @@ def from_float(
         cls,
         mod,
         emulate: bool = False,
-        scaling_type_input=TensorScalingType.DYNAMIC,
-        scaling_type_weight=TensorScalingType.DYNAMIC,
-        scaling_type_grad_output=TensorScalingType.DYNAMIC,
         config: Optional[Float8LinearConfig] = None,
     ):
         """
@@ -477,9 +452,6 @@ def from_float(
             mod.in_features,
             mod.out_features,
             bias=False,
-            scaling_type_input=scaling_type_input,
-            scaling_type_weight=scaling_type_weight,
-            scaling_type_grad_output=scaling_type_grad_output,
             emulate=emulate,
             config=config,
         )
@@ -495,15 +467,17 @@ def from_float(
         # 2. buffers need to be already created for the delayed scaling version
         # of the weight wrapper to be initialized
         if config.enable_fsdp_fp8_all_gather:
-            if scaling_type_weight is TensorScalingType.DYNAMIC:
+            if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
                         new_mod.linear_mm_config,
                     )
                 )
             else:
-                assert scaling_type_weight is TensorScalingType.DELAYED
+                assert (
+                    config.cast_config_weight.scaling_type is TensorScalingType.DELAYED
+                )
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDelayedFloat8CastTensor(
                         new_mod.weight,
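
With the changes above, `Float8Linear` reads the per-tensor scaling types from `config` in `__init__`, and `from_float` no longer accepts `scaling_type_*` kwargs. A minimal sketch of the new call pattern, based only on the hunks above; the assertion targets are attributes shown in the `__init__` hunk, and `emulate=True` is an assumption made here so the example does not require float8-capable hardware:

import torch

from float8_experimental.config import (
    Float8LinearConfig,
    Float8TensorCastConfig,
    TensorScalingType,
)
from float8_experimental.float8_linear import Float8Linear

lin = torch.nn.Linear(64, 64, bias=False)
config = Float8LinearConfig(
    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
)
fp8_lin = Float8Linear.from_float(lin, emulate=True, config=config)

# the per-tensor scaling types are now derived from the config object
assert fp8_lin.scaling_type_weight is TensorScalingType.DELAYED
assert fp8_lin.scaling_type_input is TensorScalingType.DYNAMIC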
