From 73fd16885200027278f70078098c4130d79916b5 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Tue, 9 Jul 2024 18:44:51 -0700
Subject: [PATCH 1/2] fix linter error in CI (#313)

Summary:
`pre-commit run --all-files` complains about a linter error from trunk; fix it
in this PR.

Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/313

Reviewed By: drisspg

Differential Revision: D59562565

Pulled By: weifengpy

fbshipit-source-id: b276413d2a6b25632690d59ea8d4b3f5b680a66a
---
 README.md                                   | 2 +-
 benchmarks/bench_multi_gpu.py               | 2 +-
 benchmarks/profile_linear_float8.py         | 2 +-
 float8_experimental/float8_dynamic_utils.py | 3 ---
 float8_experimental/float8_linear_utils.py  | 4 +---
 test/test_dtensor.py                        | 2 +-
 test/test_fsdp.py                           | 4 ++--
 test/test_fsdp2/test_fsdp2_common.py        | 3 +--
 test/test_fsdp2/test_fsdp2_eager.py         | 3 +--
 test/test_fsdp_compile.py                   | 2 +-
 test/test_inference_flows.py                | 2 +-
 test/test_numerics_integration.py           | 2 +-
 12 files changed, 12 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index fa093c3..ff19b93 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ m = Model(...)
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
 swap_linear_with_float8_linear(
-    m, 
+    m,
     Float8Linear,
     scaling_type_x=TensorScalingType.DELAYED,
     scaling_type_w=TensorScalingType.DELAYED,
diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py
index 12a1ddb..00a549c 100644
--- a/benchmarks/bench_multi_gpu.py
+++ b/benchmarks/bench_multi_gpu.py
@@ -14,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py
index 1ef5478..503a01a 100644
--- a/benchmarks/profile_linear_float8.py
+++ b/benchmarks/profile_linear_float8.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index f48424c..7f44363 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -9,10 +9,7 @@
 
 from typing import Any, Optional, Tuple
 
-import float8_experimental.config as config
-
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 945f7a6..5d49e65 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -3,10 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import copy
 import logging
-from enum import auto, Enum
-from typing import Callable, List, Optional, Type, Union
+from typing import Callable, List, Optional
 
 import torch
 import torch.distributed as dist
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 2088b78..8aada4b 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -15,7 +15,7 @@
 import torch.nn.functional as F
 
 from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
diff --git a/test/test_fsdp.py b/test/test_fsdp.py
index 79bba19..48b28da 100644
--- a/test/test_fsdp.py
+++ b/test/test_fsdp.py
@@ -21,7 +21,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,
@@ -149,7 +149,7 @@ def forward_backward(model, optim, is_fp8, i):
         model_fp8 = torch.compile(model_fp8)
     y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
     y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
-    local_sqnr = compute_error(y_local, y_local_fp8)
+    local_sqnr = compute_error(y_local, y_local_fp8)  # noqa: F841
 
     # get global y
     y_global = [
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py
index c20e8cc..9d42b56 100644
--- a/test/test_fsdp2/test_fsdp2_common.py
+++ b/test/test_fsdp2/test_fsdp2_common.py
@@ -1,12 +1,11 @@
 import contextlib
-from typing import List, Type
+from typing import List
 
 import float8_experimental.config as config
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear
 
 
 def check_parity_no_mp(
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index 57123cd..5ca483f 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -1,5 +1,4 @@
 import copy
-import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -9,7 +8,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py
index 715db29..3f1b5dc 100644
--- a/test/test_fsdp_compile.py
+++ b/test/test_fsdp_compile.py
@@ -18,7 +18,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py
index 1dd09d9..55543ae 100644
--- a/test/test_inference_flows.py
+++ b/test/test_inference_flows.py
@@ -13,7 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error
diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py
index 401d0fd..845c9ea 100644
--- a/test/test_numerics_integration.py
+++ b/test/test_numerics_integration.py
@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     swap_linear_with_float8_linear,

From 6cba2aeade7f2500d7b32c8e38106847201d7feb Mon Sep 17 00:00:00 2001
From: willfengg
Date: Thu, 11 Jul 2024 17:22:04 -0700
Subject: [PATCH 2/2] precompute scale after optimizer.step for dynamic
 scaling (#266)

Summary:
Goal: improve float8 all-gather perf in FSDP2 by precomputing scales for all
float8 params with a single all-reduce.

Updated the README for API usage: call `precompute_float8_dynamic_scale_for_fsdp`
inside the training loop after the optimizer step.

```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

Unit test: `pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic`

**FSDP pre-forward**: shortened from 3ms to 1.8ms by issuing one all-reduce
instead of N small all-reduces (profiler trace screenshots: before and after).

**Pre-computing amax**: shortened from 5ms to 1.7ms by switching from
`torch._foreach_abs` + `torch.max(a)` to
`torch._foreach_norm(weights, ord=math.inf)` (profiler trace screenshots:
before and after).
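For reference, here is a minimal standalone sketch (not part of this patch) of
the batched amax/scale math described above. It uses plain tensors, and the
`weights` list below is a hypothetical stand-in for the FSDP2 `DTensor`
weights; with real DTensors, the clamp on the stacked amax tensor is where the
single all-reduce would be issued.

```
import math

import torch

# hypothetical stand-ins for the sharded float8 weights (plain tensors, not DTensors)
weights = [torch.randn(16, 32), torch.randn(32, 64)]

# the inf-norm of a tensor equals max(abs(w)); _foreach_norm computes it for
# all weights in one batched call instead of one reduction per weight
amaxes = torch._foreach_norm(weights, ord=math.inf)
assert all(torch.allclose(a, w.abs().max()) for a, w in zip(amaxes, weights))

# stacking the per-weight amaxes lets all scales be derived in one shot; the
# small clamp value plays the role of EPS from float8_utils
amax_tensor = torch.vstack(amaxes)
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax_tensor, 1e-12)
```

Batching the per-parameter amaxes into one stacked tensor is what collapses N
small collectives into a single one.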
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/266

Reviewed By: vkuzo

Differential Revision: D59562409

Pulled By: weifengpy

fbshipit-source-id: 683c4719e20f6b30f39ca9109ee29e53981a2aec
---
 README.md                                   | 14 +++++-
 float8_experimental/float8_dynamic_utils.py | 43 ++++++++++++++---
 float8_experimental/fsdp_utils.py           | 52 +++++++++++++++++++++
 test/test_fsdp2/test_fsdp2_common.py        |  4 ++
 test/test_fsdp2/test_fsdp2_eager.py         | 21 +++++++--
 5 files changed, 122 insertions(+), 12 deletions(-)
 create mode 100644 float8_experimental/fsdp_utils.py

diff --git a/README.md b/README.md
index ff19b93..464e9b1 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from float8_experimental.float8_linear import Float8Linear
 
 # create model
@@ -51,7 +52,18 @@ model = FSDP(model, use_orig_params=True)
 # optional: enable torch.compile for improved performance
 m = torch.compile(m)
 
-# train/finetune (not shown)
+# toy training loop
+for _ in range(N_ITER):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+
+    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
+    # this method is optional but is highly recommended for performance
+    # it calculates scales for all parameters in a single all-reduce
+    precompute_float8_dynamic_scale_for_fsdp(model)
+
 ```
 
 ## float8 linear with delayed scaling
diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index 7f44363..b355098 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -82,7 +82,12 @@ def cast_to_float8_e5m2_dynamic_bw(
 
 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -96,9 +101,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )
 
-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        # for dynamic scaling
+        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+        # for all float8 parameters after optimizer step
+        self._precomputed_scale = precomputed_scale
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -127,20 +141,35 @@ def unwrap(t):
         )
 
     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
+        else:
+            return ["_tensor"], self._mm_config
 
     @staticmethod
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"],
+            mm_config,
+            inner_tensors.get("_precomputed_scale", None),
+        )
 
     def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
 
     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._precomputed_scale is not None:
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3_dynamic(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
         return (float8_tensor._data,), (float8_tensor._scale,)
 
     def fsdp_post_all_gather(
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
new file mode 100644
index 0000000..0ade173
--- /dev/null
+++ b/float8_experimental/fsdp_utils.py
@@ -0,0 +1,52 @@
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_utils import EPS
+
+
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
+    """
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    scales for all float8 weights.
+    Example usage:
+        model(input).sum().backward()
+        optim.step()
+        precompute_float8_dynamic_scale_for_fsdp(model)
+    """
+    from torch.distributed._tensor import DTensor
+
+    if any(
+        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
+        for m in module.modules()
+    ):
+        raise NotImplementedError("Only supports dynamic scaling")
+    float8_linears: List[Float8Linear] = [
+        m
+        for m in module.modules()
+        if isinstance(m, Float8Linear)
+        and isinstance(m.weight, DTensor)
+        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+    ]
+    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+    if not weights:
+        return
+
+    # inf-norm is equivalent to max(abs(w))
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    amax_tensor = torch.vstack(max_weights)  # Partial
+    # clamp is dispatched through DTensor
+    # it will issue a single all-reduce
+    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if amax_tensor.dtype is torch.float16:
+        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    scales = torch.split(scale_tensor, 1)  # Replicate
+    for scale, float8_linear in zip(scales, float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py
index 9d42b56..af57871 100644
--- a/test/test_fsdp2/test_fsdp2_common.py
+++ b/test/test_fsdp2/test_fsdp2_common.py
@@ -6,6 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 
 def check_parity_no_mp(
@@ -15,6 +16,7 @@ def check_parity_no_mp(
     fsdp_model: nn.Module,
     fsdp_optim: torch.optim.Optimizer,
     local_inp: torch.Tensor,
+    precompute: bool = False,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -28,6 +30,8 @@ def check_parity_no_mp(
                     param.grad.div_(dist.get_world_size())
             # TODO(future): add amax syncing once delayed scaling is supported
             optim.step()
+            if model is fsdp_model and precompute:
+                precompute_float8_dynamic_scale_for_fsdp(model)
 
         test_cls.assertEqual(losses[0], losses[1])
 
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index 5ca483f..bdbc878 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -86,10 +86,21 @@ def world_size(self) -> int:
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        self.run_subtests(
+            {
+                "enable_fsdp_fp8_all_gather": [False, True],
+                "precompute": [False, True],
+            },
+            self._test_transformer_parity_dynamic,
+        )
 
-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self,
+        enable_fsdp_fp8_all_gather: bool,
+        precompute: bool,
+    ):
+        if not enable_fsdp_fp8_all_gather and precompute:
+            return
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
@@ -109,7 +120,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
-        check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp)
+        check_parity_no_mp(
+            self, ref_module, ref_optim, module, optim, local_inp, precompute
+        )
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):