diff --git a/README.md b/README.md
index fa093c3..464e9b1 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from float8_experimental.float8_linear import Float8Linear
 
 # create model
@@ -51,7 +52,18 @@ model = FSDP(model, use_orig_params=True)
 # optional: enable torch.compile for improved performance
 m = torch.compile(m)
 
-# train/finetune (not shown)
+# toy training loop
+for _ in range(N_ITER):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+
+    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
+    # this method is optional but is highly recommended for performance
+    # it calculates scales for all parameters in a single all-reduce
+    precompute_float8_dynamic_scale_for_fsdp(model)
+
 ```
 
 ## float8 linear with delayed scaling
@@ -71,7 +83,7 @@ m = Model(...)
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
 swap_linear_with_float8_linear(
-    m,
+    m,
     Float8Linear,
     scaling_type_x=TensorScalingType.DELAYED,
     scaling_type_w=TensorScalingType.DELAYED,
diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py
index 12a1ddb..00a549c 100644
--- a/benchmarks/bench_multi_gpu.py
+++ b/benchmarks/bench_multi_gpu.py
@@ -14,7 +14,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
 
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py
index 1ef5478..503a01a 100644
--- a/benchmarks/profile_linear_float8.py
+++ b/benchmarks/profile_linear_float8.py
@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index f48424c..b355098 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -9,10 +9,7 @@
 
 from typing import Any, Optional, Tuple
 
-import float8_experimental.config as config
-
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
@@ -85,7 +82,12 @@ def cast_to_float8_e5m2_dynamic_bw(
 
 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -99,9 +101,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )
 
-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        # for dynamic scaling
+        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+        # for all float8 parameters after optimizer step
+        self._precomputed_scale = precomputed_scale
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -130,20 +141,35 @@ def unwrap(t):
         )
 
     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
+        else:
+            return ["_tensor"], self._mm_config
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"],
+            mm_config,
+            getattr(inner_tensors, "_precomputed_scale", None),
+        )
 
     def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
 
     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._precomputed_scale is not None:
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3_dynamic(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
         return (float8_tensor._data,), (float8_tensor._scale,)
 
     def fsdp_post_all_gather(
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 13b47a3..e7abe5c 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -3,10 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import copy
 import logging
-from enum import auto, Enum
-from typing import Callable, List, Optional, Type, Union
+from typing import Callable, List, Optional
 
 import torch
 import torch.distributed as dist
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
new file mode 100644
index 0000000..0ade173
--- /dev/null
+++ b/float8_experimental/fsdp_utils.py
@@ -0,0 +1,52 @@
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_utils import EPS
+
+
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
+    """
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    scales for all float8 weights.
+    Example usage:
+        model(input).sum().backward()
+        optim.step()
+        precompute_float8_dynamic_scale_for_fsdp(model)
+    """
+    from torch.distributed._tensor import DTensor
+
+    if any(
+        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
+        for m in module.modules()
+    ):
+        raise NotImplementedError("Only supports dynamic scaling")
+    float8_linears: List[Float8Linear] = [
+        m
+        for m in module.modules()
+        if isinstance(m, Float8Linear)
+        and isinstance(m.weight, DTensor)
+        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+    ]
+    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+    if not weights:
+        return
+
+    # inf-norm is equivalent to max(abs(w))
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    amax_tensor = torch.vstack(max_weights)  # Partial
+    # clamp is dispatched through DTensor
+    # it will issue a single all-reduce
+    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if amax_tensor.dtype is torch.float16:
+        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    scales = torch.split(scale_tensor, 1)  # Replicate
+    for scale, float8_linear in zip(scales, float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 2088b78..8aada4b 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -15,7 +15,7 @@
 import torch.nn.functional as F
 
 from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
diff --git a/test/test_fsdp.py b/test/test_fsdp.py
index 79bba19..48b28da 100644
--- a/test/test_fsdp.py
+++ b/test/test_fsdp.py
@@ -21,7 +21,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
@@ -149,7 +149,7 @@ def forward_backward(model, optim, is_fp8, i):
         model_fp8 = torch.compile(model_fp8)
     y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
     y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
-    local_sqnr = compute_error(y_local, y_local_fp8)
+    local_sqnr = compute_error(y_local, y_local_fp8)  # noqa: F841
 
     # get global y
     y_global = [
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py
index c20e8cc..af57871 100644
--- a/test/test_fsdp2/test_fsdp2_common.py
+++ b/test/test_fsdp2/test_fsdp2_common.py
@@ -1,12 +1,12 @@
 import contextlib
-from typing import List, Type
+from typing import List
 
 import float8_experimental.config as config
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 
 def check_parity_no_mp(
@@ -16,6 +16,7 @@
     fsdp_model: nn.Module,
     fsdp_optim: torch.optim.Optimizer,
     local_inp: torch.Tensor,
+    precompute: bool = False,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -29,6 +30,8 @@
                     param.grad.div_(dist.get_world_size())
             # TODO(future): add amax syncing once delayed scaling is supported
             optim.step()
+            if model is fsdp_model and precompute:
+                precompute_float8_dynamic_scale_for_fsdp(model)
         test_cls.assertEqual(losses[0], losses[1])
 
 
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index 57123cd..bdbc878 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -1,5 +1,4 @@
 import copy
-import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -9,7 +8,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
@@ -87,10 +86,21 @@ def world_size(self) -> int:
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        self.run_subtests(
+            {
+                "enable_fsdp_fp8_all_gather": [False, True],
+                "precompute": [False, True],
+            },
+            self._test_transformer_parity_dynamic,
+        )
 
-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self,
+        enable_fsdp_fp8_all_gather: bool,
+        precompute: bool,
+    ):
+        if not enable_fsdp_fp8_all_gather and precompute:
+            return
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
@@ -110,7 +120,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
-        check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp)
+        check_parity_no_mp(
+            self, ref_module, ref_optim, module, optim, local_inp, precompute
+        )
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py
index 715db29..3f1b5dc 100644
--- a/test/test_fsdp_compile.py
+++ b/test/test_fsdp_compile.py
@@ -18,7 +18,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
 
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py
index 1dd09d9..55543ae 100644
--- a/test/test_inference_flows.py
+++ b/test/test_inference_flows.py
@@ -13,7 +13,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error
diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py
index 401d0fd..845c9ea 100644
--- a/test/test_numerics_integration.py
+++ b/test/test_numerics_integration.py
@@ -14,7 +14,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
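
For context, here is a minimal sketch of how the new `precompute_float8_dynamic_scale_for_fsdp` API is meant to compose with FSDP2 fp8 all-gather; it is illustrative only and not part of this patch. The toy model, optimizer, `fully_shard` call, and the `config.enable_fsdp_fp8_all_gather = True` setting are assumptions made for the sketch (they mirror what `test_fsdp2_eager.py` exercises), and the snippet assumes `torch.distributed` is already initialized on a CUDA host.

```python
# Illustrative sketch (not part of this patch): FSDP2 + fp8 all-gather combined
# with the precomputed dynamic scales introduced above.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

import float8_experimental.config as config
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

# assumption: enable fp8 all-gather before swapping so that Float8Linear wraps
# its weight in WeightWithDynamicFloat8CastTensor
config.enable_fsdp_fp8_all_gather = True

model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096)).cuda()
swap_linear_with_float8_linear(model, Float8Linear)  # dynamic scaling, as in the README recipe
fully_shard(model)  # FSDP2: weights are all-gathered in float8

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(10):
    x = torch.randn(16, 4096, device="cuda")
    optimizer.zero_grad()
    model(x).sum().backward()
    optimizer.step()
    # one all-reduce computes next iteration's weight scales for all float8
    # parameters, instead of one reduction per parameter at all-gather time
    precompute_float8_dynamic_scale_for_fsdp(model)
```

Without the final call, `fsdp_pre_all_gather` falls back to `cast_to_float8_e4m3_dynamic(..., reduce_amax=True)` and issues one amax reduction per weight, which is the overhead the precomputed scale avoids.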