This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 27a3277

vkuzo authored and facebook-github-bot committed
move most previously-global configs to Float8LinearConfig (#324)
Summary:
Pull Request resolved: #324

Adds a `Float8LinearConfig` to unify the user facing per-linear configuration, and moves most of the previously global config options there. In future PRs (to keep PRs small), we will move emulation, scaling and gemm configurations to also live here.

Reviewed By: weifengpy

Differential Revision: D60176981

fbshipit-source-id: 84ed7a2d0d72aee425f870786b56b8bd641595b1
1 parent 603efc2 commit 27a3277

File tree

9 files changed: +183 additions, −109 deletions
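
For orientation before the per-file diffs, here is a minimal sketch of the migration this commit makes, based only on the API shown in the diffs below; the toy `nn.Sequential` model is a placeholder and the sketch assumes the `float8_experimental` package at this commit is installed.

```python
import torch.nn as nn

from float8_experimental import Float8LinearConfig, swap_linear_with_float8_linear

# Previously, these flags lived on a global `float8_experimental.config` module
# and were mutated in place:
#
#   from float8_experimental import config
#   config.enable_amax_init = False
#   config.enable_pre_and_post_forward = False
#
# After this commit, the same options are passed explicitly per swap via a
# frozen dataclass:
m = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))
config = Float8LinearConfig(
    enable_amax_init=False,
    enable_pre_and_post_forward=False,
)
swap_linear_with_float8_linear(m, config=config)
```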

README.md

Lines changed: 12 additions & 5 deletions

@@ -91,20 +91,27 @@ from float8_experimental.float8_linear import TensorScalingType
 # create model
 m = Model(...)
 
+# optional: configure for compatibility with FSDP. Note that workarounds
+# gated with config.enable_amax_init and
+# config.enable_pre_and_post_forward are needed for
+# autocast + compile + FSDP + float8 to work
+from float8_experimental import Float8LinearConfig
+config = Float8LinearConfig(
+    enable_amax_init=False,  # only needed for autocast + compile + FSDP + float8 delayed
+    enable_pre_and_post_forward=False,  # only needed for autocast + compile + FSDP + float8 delayed
+)
+
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
 swap_linear_with_float8_linear(
     m,
     scaling_type_input=TensorScalingType.DELAYED,
     scaling_type_weight=TensorScalingType.DELAYED,
     scaling_type_grad_output=TensorScalingType.DELAYED,
+    config=config,
 )
 
-# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
-from float8_experimental import config
-config.enable_amax_init = False # only needed for autocast + compile + FSDP + float8 delayed
-config.enable_pre_and_post_forward = False # only needed for autocast + compile + FSDP + float8 delayed
+# optional: use FSDP
 model = FSDP(model, use_orig_params=True)
 
 # optional: enable torch.compile for improved performance

float8_experimental/__init__.py

Lines changed: 11 additions & 1 deletion

@@ -4,7 +4,9 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
+from float8_experimental.config import Float8LinearConfig
 from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -17,4 +19,12 @@
 
 add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])
 
-__all__ = ["Float8Tensor", "Float8Linear"]
+__all__ = [
+    # configuration
+    "Float8LinearConfig",
+    # top level UX
+    "swap_linear_with_float8_linear",
+    # TODO(future): remove Float8Tensor and Float8Linear from public API
+    "Float8Tensor",
+    "Float8Linear",
+]
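
A small sketch of what the expanded `__all__` enables, assuming the package at this commit is importable in your environment; the `print` is only there to surface the dataclass defaults.

```python
# Both the new config and the top-level swap helper are now part of the
# package's public surface, alongside the existing tensor/linear types.
from float8_experimental import (
    Float8Linear,
    Float8LinearConfig,
    Float8Tensor,
    swap_linear_with_float8_linear,
)

# The dataclass repr shows the default flag values in one place.
print(Float8LinearConfig())
```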

float8_experimental/config.py

Lines changed: 33 additions & 21 deletions

@@ -4,28 +4,40 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-# If True, on the first iteration of Float8Linear the amaxes will be
-# initialized with the incoming data. As of 2023-12-30, this doesn't work
-# with autocast + torch.compile + FSDP. Enabling this option is nice for
-# testing, but this is not necessary for real training jobs.
-enable_amax_init = True
-
-# If True, pre-forward and post-forward functions are run. As of 2023-12-30,
-# this doesn't work with autocast + torch.compile + FSDP. Enabling this
-# option is useful for safety, but not strictly necessary.
-enable_pre_and_post_forward = True
-
-# If True, then uses a tensor subclass for the fp8 linear module's weight that
-# implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
-# Only dynamic scaling is supported for now.
-enable_fsdp_fp8_all_gather = False
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class Float8LinearConfig:
+    """
+    Configuration for converting a `torch.nn.Linear` module to float8
+    for training.
+    """
+
+    # If True, on the first iteration of Float8Linear the amaxes will be
+    # initialized with the incoming data. As of 2023-12-30, this doesn't work
+    # with autocast + torch.compile + FSDP. Enabling this option is nice for
+    # testing, but this is not necessary for real training jobs.
+    enable_amax_init: bool = True
+
+    # If True, pre-forward and post-forward functions are run. As of 2023-12-30,
+    # this doesn't work with autocast + torch.compile + FSDP. Enabling this
+    # option is useful for safety, but not strictly necessary.
+    enable_pre_and_post_forward: bool = True
+
+    # If True, then uses a tensor subclass for the fp8 linear module's weight that
+    # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
+    # Only dynamic scaling is supported for now.
+    enable_fsdp_fp8_all_gather: bool = False
+
+    # If True, then prior to performing the fp8 scaled matmul we will pad the
+    # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
+    # with _scaled_mm, since it has the strong constraint that for M,N,K, both N and K must be multiples of 16.
+    # This can cause a memory spike, however, so we keep this off by default.
+    pad_inner_dim: bool = False
+
 
 # If True, use 'fnuz' float8 types for calculations.
 # Currently, ROCm only supports fnuz variants.
+# TODO(future PR): move this to Float8LinearConfig
 use_fnuz_dtype = False
-
-# If True, then prior to performing the fp8 scaled mamtmul we will pad the
-# inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
-# _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16.
-# This can cause a memory spike however so we keep this off by default.
-pad_inner_dim = False
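
Since `Float8LinearConfig` is declared with `frozen=True`, instances are immutable. A short sketch of the resulting usage pattern, with hypothetical flag values and assuming the package is installed:

```python
import dataclasses

from float8_experimental import Float8LinearConfig

config = Float8LinearConfig(pad_inner_dim=True)
print(config.enable_amax_init)  # True, the dataclass default

# Mutating a frozen dataclass raises, so the old "set a module-level flag
# later" pattern no longer applies; build a new instance instead.
try:
    config.enable_amax_init = False
except dataclasses.FrozenInstanceError:
    config = dataclasses.replace(config, enable_amax_init=False)
```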

float8_experimental/float8_linear.py

Lines changed: 25 additions & 10 deletions

@@ -12,10 +12,10 @@
 
 from typing import Optional
 
-import float8_experimental.config as config
-
 import torch
 
+from float8_experimental.config import Float8LinearConfig
+
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
@@ -173,6 +173,7 @@ def __init__(self, *args, **kwargs):
         * `scaling_type_input`: delayed vs dynamic scaling for `input`
         * `scaling_type_weight`: delayed vs dynamic scaling for `weight`
         * `scaling_type_grad_output`: delayed vs dynamic scaling for `grad_output`
+        * `config`: Float8LinearConfig
         """
 
         delayed_scaling_recipe = kwargs.pop(
@@ -188,6 +189,7 @@ def __init__(self, *args, **kwargs):
         scaling_type_grad_output = kwargs.pop(
             "scaling_type_grad_output", TensorScalingType.DYNAMIC
         )
+        config = kwargs.pop("config")
         super().__init__(*args, **kwargs)
 
         # Defines the scaling behavior of input, weight, grad_output
@@ -201,6 +203,8 @@ def __init__(self, *args, **kwargs):
             or self.scaling_type_grad_output is TensorScalingType.DELAYED
         )
 
+        self.config = config
+
         # TODO(future): have a unique recipe per buffer instead of one per
         # module, saving implementing that until we need it.
         # TODO(future): serialization for recipes
@@ -212,36 +216,42 @@ def __init__(self, *args, **kwargs):
         self.linear_mm_config = LinearMMConfig(
             # input
             ScaledMMConfig(
-                emulate, True if not emulate else False, False, config.pad_inner_dim
+                emulate,
+                True if not emulate else False,
+                False,
+                self.config.pad_inner_dim,
             ),
             # weight
             ScaledMMConfig(
-                emulate, True if not emulate else False, False, config.pad_inner_dim
+                emulate,
+                True if not emulate else False,
+                False,
+                self.config.pad_inner_dim,
             ),
             # grad_output
-            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
+            ScaledMMConfig(emulate, False, False, self.config.pad_inner_dim),
         )
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
         # TODO(future PR): add serialization for this flag
-        self.is_amax_initialized = not config.enable_amax_init
+        self.is_amax_initialized = not self.config.enable_amax_init
 
         # Syncing of amaxes and scales happens outside of this function. This
         # flag is here to enforce that the user does not forget to do this.
-        self.amax_and_scale_synced = not config.enable_amax_init
+        self.amax_and_scale_synced = not self.config.enable_amax_init
 
         # This is needed to properly handle autocast in the amax/scale
         # update function for torch.float16
         self.last_seen_input_dtype = None
 
         # pre_forward and post_forward are currently broken with FSDP
         # and torch.compile, this option can disable them
-        # Note that when using `config.enable_pre_and_post_forward = False`,
-        # it's recommended to also set `config.enable_amax_init = False`.
+        # Note that when using `self.config.enable_pre_and_post_forward = False`,
+        # it's recommended to also set `self.config.enable_amax_init = False`.
         # Otherwise, the amax buffer would never be marked as initialized and
         # would be initialized in every iteration.
-        self.enable_pre_and_post_forward = config.enable_pre_and_post_forward
+        self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward
 
     def create_buffers(self):
         # Default values for history buffers, see above TODO
@@ -450,14 +460,18 @@ def from_float(
         scaling_type_input=TensorScalingType.DYNAMIC,
         scaling_type_weight=TensorScalingType.DYNAMIC,
         scaling_type_grad_output=TensorScalingType.DYNAMIC,
+        config: Optional[Float8LinearConfig] = None,
     ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
            mod (torch.nn.Linear): nn.Linear to convert
            emulate (bool): whether to emulate fp8 matmul logic in float32
+           config (Optional[Float8LinearConfig]): configuration for conversion to float8
        """
+        if config is None:
+            config = Float8LinearConfig()
        with torch.device("meta"):
            new_mod = cls(
                mod.in_features,
@@ -467,6 +481,7 @@ def from_float(
                scaling_type_weight=scaling_type_weight,
                scaling_type_grad_output=scaling_type_grad_output,
                emulate=emulate,
+                config=config,
            )
            new_mod.weight = mod.weight
            new_mod.bias = mod.bias
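
A minimal sketch of the updated `Float8Linear.from_float` call, under the assumption that the conversion itself does not require fp8-capable hardware (only running the fp8 matmuls does); the layer sizes are arbitrary.

```python
import torch.nn as nn

from float8_experimental import Float8Linear, Float8LinearConfig

lin = nn.Linear(128, 256)

# `config` is optional; when omitted, from_float falls back to Float8LinearConfig().
fp8_lin = Float8Linear.from_float(lin, config=Float8LinearConfig(pad_inner_dim=True))

# The config is stored on the module and plumbed into its ScaledMMConfigs.
print(fp8_lin.config.pad_inner_dim)  # True
```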

float8_experimental/float8_linear_utils.py

Lines changed: 6 additions & 0 deletions

@@ -9,6 +9,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from float8_experimental.config import Float8LinearConfig
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 
 from float8_experimental.float8_utils import (
@@ -135,6 +136,7 @@ def swap_linear_with_float8_linear(
     scaling_type_input: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_weight: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_grad_output: TensorScalingType = TensorScalingType.DYNAMIC,
+    config: Float8LinearConfig = None,
 ) -> Optional[nn.Module]:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8Linear`.
@@ -148,16 +150,20 @@
         scaling_type_input (TensorScalingType): scaling type for `input`
         scaling_type_weight (TensorScalingType): scaling type for `weight`
         scaling_type_grad_output (TensorScalingType): scaling type for `grad_output`
+        config (Float8LinearConfig): configuration for conversion to float8
 
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    if config is None:
+        config = Float8LinearConfig()
     from_float = lambda m: Float8Linear.from_float(
         m,
         emulate=emulate,
         scaling_type_input=scaling_type_input,
         scaling_type_weight=scaling_type_weight,
         scaling_type_grad_output=scaling_type_grad_output,
+        config=config,
     )
     return swap_linear_layers(
         module,
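
Given the `if config is None` default added above, the following two calls are intended to be equivalent; a sketch assuming the package is installed, with a toy model standing in for real layers.

```python
import torch.nn as nn

from float8_experimental import Float8LinearConfig, swap_linear_with_float8_linear

model_a = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
model_b = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# Omitting `config` now means "use Float8LinearConfig() with all defaults".
swap_linear_with_float8_linear(model_a)
swap_linear_with_float8_linear(model_b, config=Float8LinearConfig())
```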

float8_experimental/inference.py

Lines changed: 7 additions & 3 deletions

@@ -12,8 +12,6 @@
 from enum import auto, Enum
 from typing import Callable, List, Optional
 
-import float8_experimental.config as config
-
 import torch
 import torch.nn as nn
 from float8_experimental.float8_linear_utils import swap_linear_layers
@@ -55,6 +53,12 @@ class QuantConfig:
     activation_casting: ActivationCasting
     static_quantization_scale: Optional[torch.Tensor] = None
 
+    # If True, then prior to performing the fp8 scaled matmul we will pad the
+    # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
+    # with _scaled_mm, since it has the strong constraint that for M,N,K, both N and K must be multiples of 16.
+    # This can cause a memory spike, however, so we keep this off by default.
+    pad_inner_dim = False
+
     def __post_init__(self):
         if self.activation_casting == ActivationCasting.STATIC:
             assert isinstance(
@@ -151,7 +155,7 @@ def from_float(
         quant_config (QuantConfig): Configuration for the weight and activation casting
         """
         forward_config = ScaledMMConfig(
-            False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
+            False, use_fast_accum, pad_inner_dim=quant_config.pad_inner_dim
         )
         linear_mm_config = LinearMMConfig(
             forward_config, forward_config, forward_config
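
For the inference path, `pad_inner_dim` now lives on `QuantConfig` (as a class-level default) rather than on the removed global config module. A sketch, assuming `QuantConfig` is a standard dataclass as its field annotations suggest and that the static-casting assert in `__post_init__` checks the scale is a tensor:

```python
import torch

from float8_experimental.inference import ActivationCasting, QuantConfig

# STATIC activation casting requires a pre-computed scale, enforced in __post_init__.
quant_config = QuantConfig(
    activation_casting=ActivationCasting.STATIC,
    static_quantization_scale=torch.tensor(1.0),
)

# pad_inner_dim defaults to False at the class level; from_float reads it off
# the quant_config it is given when building the forward ScaledMMConfig.
print(quant_config.pad_inner_dim)  # False
```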
