diff --git a/README.md b/README.md
index 047fe6d..b32dbec 100644
--- a/README.md
+++ b/README.md
@@ -137,7 +137,7 @@ pytest test/test_numerics_integration.py
 ./test/test_dtensor.sh
 
 # run integration tests on the FSDP2 integration
-python test/test_fsdp2/test_fsdp2_eager.py
+python test/test_fsdp2/test_fsdp2.py
 
 # run all of these tests
 ./test/test_everything.sh
diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index 9ad76f7..215a394 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -4,8 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional, Tuple
-
 import torch
 
 from float8_experimental.float8_tensor import (
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index 81d53b5..c7eb2c0 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -64,7 +64,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     scales = torch.split(scale_tensor, 1)  # Replicate
     for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+        float8_linear.weight._local_tensor._precomputed_scale = (
+            scale._local_tensor.squeeze()
+        )
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -301,7 +303,7 @@ def __tensor_flatten__(self):
             ],
             {
                 "mm_config": self._mm_config,
-                "is_amax_initialized": is_amax_initialized,
+                "is_amax_initialized": self.is_amax_initialized,
             },
         )
 
diff --git a/test/test_everything.sh b/test/test_everything.sh
index 5eeb17c..72ca42d 100755
--- a/test/test_everything.sh
+++ b/test/test_everything.sh
@@ -15,7 +15,7 @@ then
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_dtensor.sh
-pytest test/test_fsdp2/test_fsdp2_eager.py
+pytest test/test_fsdp2/test_fsdp2.py
 fi
 
 echo "all tests successful"
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2.py
similarity index 96%
rename from test/test_fsdp2/test_fsdp2_eager.py
rename to test/test_fsdp2/test_fsdp2.py
index 91c629f..1cbec77 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2.py
@@ -89,6 +89,7 @@ def test_transformer_parity(self):
                     TensorScalingType.DYNAMIC,
                     TensorScalingType.DELAYED,
                 ],
+                "compile_transformer_block": [False, True],
             },
             self._test_transformer_parity,
         )
@@ -98,6 +99,7 @@ def _test_transformer_parity(
         enable_fsdp_fp8_all_gather: bool,
         precompute: bool,
         scaling_type_w: TensorScalingType,
+        compile_transformer_block: bool,
     ):
         if not enable_fsdp_fp8_all_gather and precompute:
             return
@@ -112,11 +114,17 @@ def _test_transformer_parity(
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w)
+        if compile_transformer_block:
+            for layer_id, transformer_block in ref_module.layers.named_children():
+                transformer_block = torch.compile(transformer_block, dynamic=False)
+                ref_module.layers.register_module(layer_id, transformer_block)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
             swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
-        for submodule in module.modules():
-            if isinstance(submodule, TransformerBlock):
-                fully_shard(submodule)
+        for layer_id, transformer_block in module.layers.named_children():
+            if compile_transformer_block:
+                transformer_block = torch.compile(transformer_block, dynamic=False)
+            fully_shard(transformer_block)
+            module.layers.register_module(layer_id, transformer_block)
         fully_shard(module)
         ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
         optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -132,6 +140,7 @@ def _test_transformer_parity(
             local_inp,
             precompute,
             scaling_type_w=scaling_type_w,
+            compile_transformer_block=compile_transformer_block,
        )
 
     @skip_if_lt_x_gpu(2)
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py
index 2638401..61edac9 100644
--- a/test/test_fsdp2/test_fsdp2_common.py
+++ b/test/test_fsdp2/test_fsdp2_common.py
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
@@ -23,6 +23,7 @@ def check_parity_no_mp(
     local_inp: torch.Tensor,
     precompute: bool = False,
     scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    compile_transformer_block: bool = False,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -46,7 +47,10 @@ def check_parity_no_mp(
             ):
                 precompute_float8_dynamic_scale_for_fsdp(model)
 
-        test_cls.assertEqual(losses[0], losses[1])
+        if compile_transformer_block:
+            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
+        else:
+            test_cls.assertEqual(losses[0], losses[1])
 
 
 def check_parity_bf16_mp(
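
Below is a minimal, standalone sketch of the per-block wrapping pattern introduced in `_test_transformer_parity` above: each child of `module.layers` is wrapped with `torch.compile(..., dynamic=False)` and re-registered under its original name via `register_module`. `ToyTransformer` and its sizes are hypothetical stand-ins for the test's transformer; the per-block `fully_shard` call from the test is only noted in a comment because it requires an initialized process group.

```python
# Hypothetical toy model; only the wrapping pattern mirrors the test above.
import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    def __init__(self, num_layers: int = 2, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


def wrap_blocks(module: nn.Module, compile_transformer_block: bool = True) -> nn.Module:
    # Iterate over named children so each wrapped block can be re-registered
    # under its original name, as the test does for module.layers.
    for layer_id, block in module.layers.named_children():
        if compile_transformer_block:
            block = torch.compile(block, dynamic=False)
        # In the FSDP2 test, fully_shard(block) is also called here; omitted
        # in this sketch because it needs a distributed process group.
        module.layers.register_module(layer_id, block)
    return module


if __name__ == "__main__":
    model = wrap_blocks(ToyTransformer())
    print(model(torch.randn(4, 16)).shape)  # torch.Size([4, 16])
```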