
move tensor scaling configuration to Float8LinearConfig (#325)
Summary:
Pull Request resolved: #325

Moves the fields that configure per-tensor scaling options to
`Float8LinearConfig`, to set us up for upcoming new configurations such
as other scaling granularities, and to simplify configuration in general.
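
For context, a minimal sketch of the intended usage after this change (the API
names come from the diff below; the placeholder model and the chosen scaling
types are illustrative):

    import torch
    from float8_experimental import (
        Float8LinearConfig,
        Float8TensorCastConfig,
        TensorScalingType,
        swap_linear_with_float8_linear,
    )

    # placeholder model; any torch.nn.Module containing torch.nn.Linear layers works
    m = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))

    # per-tensor scaling choices now live on the config object
    config = Float8LinearConfig(
        cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DYNAMIC),
        cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
        cast_config_grad_output=Float8TensorCastConfig(scaling_type=TensorScalingType.DYNAMIC),
    )

    # the module swap takes the single config instead of three scaling_type_* kwargs
    swap_linear_with_float8_linear(m, config=config)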

Reviewed By: weifengpy

Differential Revision: D60176980

fbshipit-source-id: 0950241b1f5d8ffe95dd67124e93fd4aaed50cf2
vkuzo authored and facebook-github-bot committed Jul 24, 2024
1 parent 27a3277 commit f64d339
Showing 19 changed files with 278 additions and 215 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -95,19 +95,19 @@ m = Model(...)
# gated with config.enable_amax_init and
# config.enable_pre_and_post_forward are needed for
# autocast + compile + FSDP + float8 to work
from float8_experimental import Float8LinearConfig
from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig
config = Float8LinearConfig(
enable_amax_init = False, # only needed for autocast + compile + FSDP + float8 delayed
enable_pre_and_post_forward = False, # only needed for autocast + compile + FSDP + float8 delayed
cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
cast_config_grad_output=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
)

# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
# type
swap_linear_with_float8_linear(
m,
scaling_type_input=TensorScalingType.DELAYED,
scaling_type_weight=TensorScalingType.DELAYED,
scaling_type_grad_output=TensorScalingType.DELAYED,
config=config,
)

Expand Down
22 changes: 15 additions & 7 deletions benchmarks/bench_linear_float8.py
@@ -14,7 +14,12 @@

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.config import (
Float8LinearConfig,
Float8TensorCastConfig,
TensorScalingType,
)
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
linear_requires_sync,
sync_float8_amax_and_scale_history,
@@ -105,6 +110,13 @@ def main(
scaling_type_input = TensorScalingType(scaling_type_input)
scaling_type_weight = TensorScalingType(scaling_type_weight)
scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
config = Float8LinearConfig(
cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=scaling_type_grad_output
),
)

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
@@ -136,9 +148,7 @@ def main(
linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref),
emulate=False,
scaling_type_input=scaling_type_input,
scaling_type_weight=scaling_type_weight,
scaling_type_grad_output=scaling_type_grad_output,
config=config,
)
scaling_repr = linear_float8.scaling_repr()

@@ -153,9 +163,7 @@ def main(
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

def float8_forw_backward():
if linear_requires_sync(
scaling_type_input, scaling_type_weight, scaling_type_grad_output
):
if linear_requires_sync(config):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

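For reference, the sync pattern exercised by this benchmark can be written as a
standalone sketch; the device, shapes, and training loop are illustrative and
float8-capable hardware is assumed, while the API calls mirror the diff above:

    import torch
    from float8_experimental import (
        Float8LinearConfig,
        Float8TensorCastConfig,
        TensorScalingType,
        swap_linear_with_float8_linear,
    )
    from float8_experimental.float8_linear_utils import (
        linear_requires_sync,
        sync_float8_amax_and_scale_history,
    )

    # delayed scaling for the weight only; input and grad_output keep the dynamic default
    config = Float8LinearConfig(
        cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
    )

    m = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).cuda().bfloat16()
    swap_linear_with_float8_linear(m, config=config)

    x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
    for _ in range(10):
        # with any delayed-scaled tensor, amax/scale history must be synced every step
        if linear_requires_sync(config):
            sync_float8_amax_and_scale_history(m)
        m(x).sum().backward()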
18 changes: 14 additions & 4 deletions benchmarks/bench_multi_gpu.py
@@ -14,7 +14,11 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.config import (
Float8LinearConfig,
Float8TensorCastConfig,
TensorScalingType,
)
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
@@ -28,6 +32,14 @@
B, M, K, N = 32, 1024, 1024, 1024
lr = 0.01

config = Float8LinearConfig(
cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=TensorScalingType.DELAYED
),
)


def benchmark_torch_function_in_microseconds(
func: Callable,
@@ -68,9 +80,7 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
swap_linear_with_float8_linear(
m,
emulate=False,
scaling_type_input=TensorScalingType.DELAYED,
scaling_type_weight=TensorScalingType.DELAYED,
scaling_type_grad_output=TensorScalingType.DELAYED,
config=config,
)
return m

25 changes: 14 additions & 11 deletions benchmarks/profile_linear_float8.py
@@ -18,7 +18,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.config import (
Float8LinearConfig,
Float8TensorCastConfig,
TensorScalingType,
)
from float8_experimental.float8_linear_utils import (
linear_requires_sync,
swap_linear_with_float8_linear,
@@ -216,6 +220,13 @@ def main(
scaling_type_input = TensorScalingType(scaling_type_input)
scaling_type_weight = TensorScalingType(scaling_type_weight)
scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
config = Float8LinearConfig(
cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
cast_config_grad_output=Float8TensorCastConfig(
scaling_type=scaling_type_grad_output
),
)
scaling_repr = "_".join(
[
s.short_str()
@@ -256,14 +267,8 @@ def main(

m_ref = m_ref.to(device).to(ref_dtype)

extra_kwargs = {
"scaling_type_input": scaling_type_input,
"scaling_type_weight": scaling_type_weight,
"scaling_type_grad_output": scaling_type_grad_output,
}

m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(m_float8, **extra_kwargs)
swap_linear_with_float8_linear(m_float8, config=config)

def ref_forw_backward(x):
out = m_ref(x)
@@ -281,9 +286,7 @@ def float8_forw_backward_wrapper(x):
# inspection of the fw+bw torch.compile without the scale
# syncing code
# TODO(future): make this better
if linear_requires_sync(
scaling_type_input, scaling_type_weight, scaling_type_grad_output
):
if linear_requires_sync(config):
with record_function("scale_amax_and_scales"):
sync_amax_history(m_float8)
out = float8_forw(x)
8 changes: 7 additions & 1 deletion float8_experimental/__init__.py
@@ -4,7 +4,11 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# Lets define a few top level things here
from float8_experimental.config import Float8LinearConfig
from float8_experimental.config import (
Float8LinearConfig,
Float8TensorCastConfig,
TensorScalingType,
)
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor import (
@@ -21,7 +25,9 @@

__all__ = [
# configuration
"TensorScalingType",
"Float8LinearConfig",
"Float8TensorCastConfig",
# top level UX
"swap_linear_with_float8_linear",
# TODO(future): remove Float8Tensor and Float8Linear from public API
33 changes: 33 additions & 0 deletions float8_experimental/config.py
@@ -4,16 +4,49 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import enum
from dataclasses import dataclass


class TensorScalingType(enum.Enum):
DELAYED = "delayed"
DYNAMIC = "dynamic"

def short_str(self):
if self is TensorScalingType.DELAYED:
return "del"
else:
assert self is TensorScalingType.DYNAMIC
return "dyn"


@dataclass(frozen=True)
class Float8TensorCastConfig:
"""
Configuration for casting a single tensor to float8
"""

scaling_type: TensorScalingType = TensorScalingType.DYNAMIC


@dataclass(frozen=True)
class Float8LinearConfig:
"""
Configuration for converting a `torch.nn.Linear` module to float8
for training.
"""

#
# Per-tensor configuration for `input`, `weight`, `grad_output`
#
cast_config_input: Float8TensorCastConfig = Float8TensorCastConfig()
cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()

#
# Per-linear configuration
#

# If True, on the first iteration of Float8Linear the amaxes will be
# initialized with the incoming data. As of 2023-12-30, this doesn't work
# with autocast + torch.compile + FSDP. Enabling this option is nice for
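Taken together, the dataclasses above compose as follows; a small sketch,
assuming the truncated fields keep their defaults:

    from float8_experimental.config import (
        Float8LinearConfig,
        Float8TensorCastConfig,
        TensorScalingType,
    )

    # defaults: every tensor is cast with dynamic scaling
    default_config = Float8LinearConfig()
    assert default_config.cast_config_input.scaling_type is TensorScalingType.DYNAMIC

    # per-tensor override: delayed scaling for the weight only
    mixed = Float8LinearConfig(
        cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
    )
    print(mixed.cast_config_weight.scaling_type.short_str())  # prints "del"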
42 changes: 8 additions & 34 deletions float8_experimental/float8_linear.py
@@ -14,7 +14,7 @@

import torch

from float8_experimental.config import Float8LinearConfig
from float8_experimental.config import Float8LinearConfig, TensorScalingType

from float8_experimental.float8_dynamic_utils import (
cast_to_float8_e4m3_dynamic,
@@ -148,18 +148,6 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


class TensorScalingType(enum.Enum):
DELAYED = "delayed"
DYNAMIC = "dynamic"

def short_str(self):
if self is TensorScalingType.DELAYED:
return "del"
else:
assert self is TensorScalingType.DYNAMIC
return "dyn"


class Float8Linear(torch.nn.Linear):
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
@@ -170,9 +158,6 @@ def __init__(self, *args, **kwargs):
"""
Additional arguments on top of `torch.nn.Linear`'s arguments:
* `delayed_scaling_recipe`: configuration for delayed scaling
* `scaling_type_input`: delayed vs dynamic scaling for `input`
* `scaling_type_weight`: delayed vs dynamic scaling for `weight`
* `scaling_type_grad_output`: delayed vs dynamic scaling for `grad_output`
* `config`: Float8LinearConfig
"""

@@ -182,20 +167,13 @@ def __init__(self, *args, **kwargs):
# Amax scales should always be kept as float32.
self.always_float32_buffers = set()
emulate = kwargs.pop("emulate", False)
scaling_type_input = kwargs.pop("scaling_type_input", TensorScalingType.DYNAMIC)
scaling_type_weight = kwargs.pop(
"scaling_type_weight", TensorScalingType.DYNAMIC
)
scaling_type_grad_output = kwargs.pop(
"scaling_type_grad_output", TensorScalingType.DYNAMIC
)
config = kwargs.pop("config")
super().__init__(*args, **kwargs)

# Defines the scaling behavior of input, weight, grad_output
self.scaling_type_input = scaling_type_input
self.scaling_type_weight = scaling_type_weight
self.scaling_type_grad_output = scaling_type_grad_output
self.scaling_type_input = config.cast_config_input.scaling_type
self.scaling_type_weight = config.cast_config_weight.scaling_type
self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
# Convenience flag to skip code related to delayed scaling
self.has_any_delayed_scaling = (
self.scaling_type_input is TensorScalingType.DELAYED
@@ -457,9 +435,6 @@ def from_float(
cls,
mod,
emulate: bool = False,
scaling_type_input=TensorScalingType.DYNAMIC,
scaling_type_weight=TensorScalingType.DYNAMIC,
scaling_type_grad_output=TensorScalingType.DYNAMIC,
config: Optional[Float8LinearConfig] = None,
):
"""
Expand All @@ -477,9 +452,6 @@ def from_float(
mod.in_features,
mod.out_features,
bias=False,
scaling_type_input=scaling_type_input,
scaling_type_weight=scaling_type_weight,
scaling_type_grad_output=scaling_type_grad_output,
emulate=emulate,
config=config,
)
@@ -495,15 +467,17 @@ def from_float(
# 2. buffers need to be already created for the delayed scaling version
# of the weight wrapper to be initialized
if config.enable_fsdp_fp8_all_gather:
if scaling_type_weight is TensorScalingType.DYNAMIC:
if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
new_mod.linear_mm_config,
)
)
else:
assert scaling_type_weight is TensorScalingType.DELAYED
assert (
config.cast_config_weight.scaling_type is TensorScalingType.DELAYED
)
new_mod.weight = torch.nn.Parameter(
WeightWithDelayedFloat8CastTensor(
new_mod.weight,
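Per the new `from_float` signature above, converting a single module now looks
roughly like this; a sketch that passes an explicit config rather than relying
on the default handling, which is not visible in this diff:

    import torch
    from float8_experimental.config import Float8LinearConfig
    from float8_experimental.float8_linear import Float8Linear

    lin = torch.nn.Linear(1024, 1024, bias=False)
    # per-tensor scaling behavior is read from config.cast_config_* inside __init__
    fp8_lin = Float8Linear.from_float(lin, config=Float8LinearConfig())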
(diffs for the remaining changed files are not shown)
